55To add a new kernel:
661. Implement the kernel in the all namespace (e.g., AVX10_2, AVX512, DEFAULT). See existing kernel files for reference.
 Note: Kernel files must be named in the format <kernel_name>_krnl.cpp (e.g., da8w4_linear_krnl.cpp).
8- 2. Declare the kernel function using the corresponding macro (e.g., declare_da8w4_linear_impl) in the same namespace.
9- 3. Add a call macro (e.g., call_da8w4_linear_impl) in the same namespace that calls the implemented kernel function.
10- 4. Add a dispatch function outside ISA-related namespace, which calls the appropriate kernel based on the available hardware features.
8+ 2. Declare the kernel function type as <kernel_name>_fn.
9+ 3. Add an entry in the KernelDispatcher struct for the new kernel.
10+ 4. Add a declaration of the kernel function in all namespaces.
11+ 5. Add an entry in the get_kernel_dispatcher function.
6. Add a wrapper that calls the kernel that the dispatcher selects at runtime.
7. Register the Python op with the wrapper.
1114*/
1215namespace torchao {
1316
17+ /* ********* Lightweight ISA-based Dispatcher **********/
18+ // Function pointer types for each kernel
19+ using da8w4_linear_prepack_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(*)(
20+ const at::Tensor&, const at::Tensor&, const at::Tensor&);
21+
22+ using da8w4_linear_fn = at::Tensor(*)(
23+ const at::Tensor&, const at::Tensor&, const at::Tensor&,
24+ const at::Tensor&, const at::Tensor&, const at::Tensor&,
25+ const at::Tensor&, const std::optional<at::Tensor>&, at::ScalarType);
26+
27+ using float8_linear_prepack_fn = std::tuple<at::Tensor, at::Tensor>(*)(
28+ const at::Tensor&, const at::Tensor&);
29+
30+ using float8_linear_fn = at::Tensor(*)(
31+ const at::Tensor&, const at::Tensor&,
32+ const at::Tensor&, const at::Tensor&,
33+ const std::optional<at::Tensor>&, at::ScalarType);
34+
35+ using scaled_embedding_bag_fn = at::Tensor(*)(
36+ const at::Tensor&, const at::Tensor&, const at::Tensor&,
37+ const at::Tensor&, double , int64_t , bool , at::ScalarType);
38+
39+ using qscaled_dot_product_fn = at::Tensor(*)(
40+ const at::Tensor&, const at::Tensor&, const at::Tensor&,
41+ std::optional<at::Tensor>, double , bool , std::optional<double >,
42+ double , int64_t , double , int64_t , double , int64_t ,
43+ double , int64_t , double , int64_t );
44+
45+ // Dispatcher table: holds function pointers for all kernels
46+ struct KernelDispatcher {
47+ da8w4_linear_prepack_fn da8w4_linear_prepack;
48+ da8w4_linear_fn da8w4_linear;
49+ float8_linear_prepack_fn float8_linear_prepack;
50+ float8_linear_fn float8_linear;
51+ scaled_embedding_bag_fn scaled_embedding_bag;
52+ qscaled_dot_product_fn qscaled_dot_product;
53+ };
54+
1455/* ********* DA8W4 Linear Kernel Declare **********/
1556#define declare_da8w4_linear_prepack_impl \
1657 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> \
@@ -31,21 +72,6 @@ namespace torchao {
3172 const std::optional<at::Tensor>& bias, \
3273 at::ScalarType output_dtype)
3374
34- #define call_da8w4_linear_prepack_impl () \
35- da8w4_linear_prepack_impl (weight, scales, qzeros)
36-
37- #define call_da8w4_linear_impl () \
38- da8w4_linear_impl ( \
39- input, \
40- input_scales, \
41- input_qzeros, \
42- weight, \
43- weight_scales, \
44- weight_qzeros, \
45- compensation, \
46- bias, \
47- output_dtype)
48-
4975/* ********* FLOAT8 Linear Kernel Declare **********/
5076#define declare_float8_linear_prepack_impl \
5177 std::tuple<at::Tensor, at::Tensor> \
@@ -62,18 +88,6 @@ namespace torchao {
6288 const std::optional<at::Tensor>& bias, \
6389 at::ScalarType output_dtype)
6490
65- #define call_float8_linear_prepack_impl () \
66- float8_linear_prepack_impl (weight, scales)
67-
68- #define call_float8_linear_impl () \
69- float8_linear_impl ( \
70- input, \
71- input_scales, \
72- weight, \
73- weight_scales, \
74- bias, \
75- output_dtype)
76-
7791/* ********* Scaled Embedding Bag Kernel Declare **********/
7892#define declare_scaled_embedding_bag_impl \
7993 at::Tensor _scaled_embedding_bag_impl ( \
@@ -86,17 +100,6 @@ namespace torchao {
86100 bool include_last_offset, \
87101 at::ScalarType output_dtype)
88102
89- #define call_scaled_embedding_bag_impl () \
90- _scaled_embedding_bag_impl ( \
91- qweight, \
92- indices, \
93- offsets, \
94- w_scales, \
95- o_scale, \
96- mode, \
97- include_last_offset, \
98- output_dtype)
99-
100103/* ********* Quantized SDPA Kernel Declare **********/
101104#define declare_qscaled_dot_product_impl \
102105 at::Tensor _qscaled_dot_product_cpu ( \
@@ -118,26 +121,6 @@ namespace torchao {
118121 double o_scale, \
119122 int64_t o_zp)
120123
121- #define call_qscaled_dot_product_impl () \
122- _qscaled_dot_product_cpu ( \
123- query, \
124- key, \
125- value, \
126- attn_mask, \
127- dropout_p, \
128- is_causal, \
129- scale, \
130- q_scale, \
131- q_zp, \
132- k_scale, \
133- k_zp, \
134- v_scale, \
135- v_zp, \
136- a_scale, \
137- a_zp, \
138- o_scale, \
139- o_zp)
140-
141124/* ********* Declare All Kernels in All Namespaces **********/
142125#define declare_all_kernels (namespace_name ) \
143126 namespace namespace_name { \
@@ -153,93 +136,75 @@ declare_all_kernels(AVX10_2)
153136declare_all_kernels(AVX512)
154137declare_all_kernels(DEFAULT)
155138
156- /* ********* DA8W4 Linear Kernel Dispatch **********/
157- declare_da8w4_linear_prepack_impl {
158- // BUILD_AVX10_2 is only defined when the compiler passed the AVX10.2 ISA probe in setup.py.
139+ /* ********* Dispatcher Selection and Dispatch Functions **********/
140+ // Select the appropriate dispatcher based on runtime ISA capabilities
141+ KernelDispatcher& get_kernel_dispatcher() {
142+ static KernelDispatcher dispatcher = []() {
143+ KernelDispatcher d;
144+ // Select ISA level based on runtime detection (kHas*) and compile-time checks (BUILD_*)
159145#if defined(BUILD_AVX10_2)
160- if (kHasAVX10_2 ) {
161- return AVX10_2::call_da8w4_linear_prepack_impl ();
162- }
146+ if (kHasAVX10_2 ) {
147+ d = {AVX10_2::da8w4_linear_prepack_impl,
148+ AVX10_2::da8w4_linear_impl,
149+ AVX10_2::float8_linear_prepack_impl,
150+ AVX10_2::float8_linear_impl,
151+ AVX10_2::_scaled_embedding_bag_impl,
152+ AVX10_2::_qscaled_dot_product_cpu};
153+ return d;
154+ }
163155#endif
164156#if defined(BUILD_AVX512)
165- if (kHasAVX512 ) {
166- return AVX512::call_da8w4_linear_prepack_impl ();
167- }
157+ if (kHasAVX512 ) {
158+ d = {AVX512::da8w4_linear_prepack_impl,
159+ AVX512::da8w4_linear_impl,
160+ AVX512::float8_linear_prepack_impl,
161+ AVX512::float8_linear_impl,
162+ AVX512::_scaled_embedding_bag_impl,
163+ AVX512::_qscaled_dot_product_cpu};
164+ return d;
165+ }
168166#endif
169- return DEFAULT::call_da8w4_linear_prepack_impl ();
167+ // Fall back to DEFAULT (always available)
168+ d = {DEFAULT::da8w4_linear_prepack_impl,
169+ DEFAULT::da8w4_linear_impl,
170+ DEFAULT::float8_linear_prepack_impl,
171+ DEFAULT::float8_linear_impl,
172+ DEFAULT::_scaled_embedding_bag_impl,
173+ DEFAULT::_qscaled_dot_product_cpu};
174+ return d;
175+ }();
176+ return dispatcher;
177+ }
178+
179+ declare_da8w4_linear_prepack_impl {
180+ return get_kernel_dispatcher ().da8w4_linear_prepack (weight, scales, qzeros);
170181}
171182
172183declare_da8w4_linear_impl {
173- #if defined(BUILD_AVX10_2)
174- if (kHasAVX10_2 ) {
175- return AVX10_2::call_da8w4_linear_impl ();
176- }
177- #endif
178- #if defined(BUILD_AVX512)
179- if (kHasAVX512 ) {
180- return AVX512::call_da8w4_linear_impl ();
181- }
182- #endif
183- return DEFAULT::call_da8w4_linear_impl ();
184+ return get_kernel_dispatcher ().da8w4_linear (
185+ input, input_scales, input_qzeros, weight, weight_scales, weight_qzeros,
186+ compensation, bias, output_dtype);
184187}
185188
186- /* ********* FLOAT8 Linear Kernel Dispatch **********/
187189declare_float8_linear_prepack_impl {
188- #if defined(BUILD_AVX10_2)
189- if (kHasAVX10_2 ) {
190- return AVX10_2::call_float8_linear_prepack_impl ();
191- }
192- #endif
193- #if defined(BUILD_AVX512)
194- if (kHasAVX512 ) {
195- return AVX512::call_float8_linear_prepack_impl ();
196- }
197- #endif
198- return DEFAULT::call_float8_linear_prepack_impl ();
190+ return get_kernel_dispatcher ().float8_linear_prepack (weight, scales);
199191}
200192
201193declare_float8_linear_impl {
202- #if defined(BUILD_AVX10_2)
203- if (kHasAVX10_2 ) {
204- return AVX10_2::call_float8_linear_impl ();
205- }
206- #endif
207- #if defined(BUILD_AVX512)
208- if (kHasAVX512 ) {
209- return AVX512::call_float8_linear_impl ();
210- }
211- #endif
212- return DEFAULT::call_float8_linear_impl ();
194+ return get_kernel_dispatcher ().float8_linear (
195+ input, input_scales, weight, weight_scales, bias, output_dtype);
213196}
214197
215- /* ********* Scaled Embedding Bag Kernel Dispatch **********/
216198declare_scaled_embedding_bag_impl {
217- #if defined(BUILD_AVX10_2)
218- if (kHasAVX10_2 ) {
219- return AVX10_2::call_scaled_embedding_bag_impl ();
220- }
221- #endif
222- #if defined(BUILD_AVX512)
223- if (kHasAVX512 ) {
224- return AVX512::call_scaled_embedding_bag_impl ();
225- }
226- #endif
227- return DEFAULT::call_scaled_embedding_bag_impl ();
199+ return get_kernel_dispatcher ().scaled_embedding_bag (
200+ qweight, indices, offsets, w_scales, o_scale, mode, include_last_offset,
201+ output_dtype);
228202}
229203
230- /* ********* Quantized SDPA Kernel **********/
231204declare_qscaled_dot_product_impl {
232- #if defined(BUILD_AVX10_2)
233- if (kHasAVX10_2 ) {
234- return AVX10_2::call_qscaled_dot_product_impl ();
235- }
236- #endif
237- #if defined(BUILD_AVX512)
238- if (kHasAVX512 ) {
239- return AVX512::call_qscaled_dot_product_impl ();
240- }
241- #endif
242- return DEFAULT::call_qscaled_dot_product_impl ();
205+ return get_kernel_dispatcher ().qscaled_dot_product (
206+ query, key, value, attn_mask, dropout_p, is_causal, scale, q_scale, q_zp,
207+ k_scale, k_zp, v_scale, v_zp, a_scale, a_zp, o_scale, o_zp);
243208}
244209
245210TORCH_LIBRARY_IMPL (torchao, CPU, m) {
0 commit comments