Skip to content

Commit fee9b02

Browse files
author
Copilot
committed
Remove unused code
1 parent 9ec0d6b commit fee9b02

2 files changed

Lines changed: 95 additions & 129 deletions

File tree

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ def add_compile_flags(extra_compile_args: dict) -> None:
545545
if not X86KernelBuild.is_enabled():
546546
return
547547
flags = [
548+
"-DCPU_CAPABILITY=DEFAULT",
548549
"-fno-tree-vectorize",
549550
"-fopenmp",
550551
]

torchao/csrc/cpu/aten_kernels/dispatch.cpp

Lines changed: 94 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,53 @@
55
To add a new kernel:
66
1. Implement the kernel in all ISA namespaces (e.g., AVX10_2, AVX512, DEFAULT). See existing kernel files for reference.
77
Note: Kernel files must be named in the format of <kernel_name>_krnl.cpp (e.g., da8w4_linear_krnl.cpp).
8-
2. Declare the kernel function using the corresponding macro (e.g., declare_da8w4_linear_impl) in the same namespace.
9-
3. Add a call macro (e.g., call_da8w4_linear_impl) in the same namespace that calls the implemented kernel function.
10-
4. Add a dispatch function outside ISA-related namespace, which calls the appropriate kernel based on the available hardware features.
8+
2. Declare the kernel function type as <kernel_name>_fn.
9+
3. Add an entry in the KernelDispatcher struct for the new kernel.
10+
4. Add a declaration of the kernel function in all namespaces.
11+
5. Add an entry in the get_kernel_dispatcher function.
12+
6. Add a wrapper that calls the kernel that the dispatcher selects at runtime.
13+
7. Register the Python op with the wrapper.
1114
*/
1215
namespace torchao {
1316

17+
/********** Lightweight ISA-based Dispatcher **********/
18+
// Function pointer types for each kernel
19+
using da8w4_linear_prepack_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(*)(
20+
const at::Tensor&, const at::Tensor&, const at::Tensor&);
21+
22+
using da8w4_linear_fn = at::Tensor(*)(
23+
const at::Tensor&, const at::Tensor&, const at::Tensor&,
24+
const at::Tensor&, const at::Tensor&, const at::Tensor&,
25+
const at::Tensor&, const std::optional<at::Tensor>&, at::ScalarType);
26+
27+
using float8_linear_prepack_fn = std::tuple<at::Tensor, at::Tensor>(*)(
28+
const at::Tensor&, const at::Tensor&);
29+
30+
using float8_linear_fn = at::Tensor(*)(
31+
const at::Tensor&, const at::Tensor&,
32+
const at::Tensor&, const at::Tensor&,
33+
const std::optional<at::Tensor>&, at::ScalarType);
34+
35+
using scaled_embedding_bag_fn = at::Tensor(*)(
36+
const at::Tensor&, const at::Tensor&, const at::Tensor&,
37+
const at::Tensor&, double, int64_t, bool, at::ScalarType);
38+
39+
using qscaled_dot_product_fn = at::Tensor(*)(
40+
const at::Tensor&, const at::Tensor&, const at::Tensor&,
41+
std::optional<at::Tensor>, double, bool, std::optional<double>,
42+
double, int64_t, double, int64_t, double, int64_t,
43+
double, int64_t, double, int64_t);
44+
45+
// Dispatcher table: holds function pointers for all kernels
46+
struct KernelDispatcher {
47+
da8w4_linear_prepack_fn da8w4_linear_prepack;
48+
da8w4_linear_fn da8w4_linear;
49+
float8_linear_prepack_fn float8_linear_prepack;
50+
float8_linear_fn float8_linear;
51+
scaled_embedding_bag_fn scaled_embedding_bag;
52+
qscaled_dot_product_fn qscaled_dot_product;
53+
};
54+
1455
/********** DA8W4 Linear Kernel Declare **********/
1556
#define declare_da8w4_linear_prepack_impl \
1657
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> \
@@ -31,21 +72,6 @@ namespace torchao {
3172
const std::optional<at::Tensor>& bias, \
3273
at::ScalarType output_dtype)
3374

34-
#define call_da8w4_linear_prepack_impl() \
35-
da8w4_linear_prepack_impl(weight, scales, qzeros)
36-
37-
#define call_da8w4_linear_impl() \
38-
da8w4_linear_impl( \
39-
input, \
40-
input_scales, \
41-
input_qzeros, \
42-
weight, \
43-
weight_scales, \
44-
weight_qzeros, \
45-
compensation, \
46-
bias, \
47-
output_dtype)
48-
4975
/********** FLOAT8 Linear Kernel Declare **********/
5076
#define declare_float8_linear_prepack_impl \
5177
std::tuple<at::Tensor, at::Tensor> \
@@ -62,18 +88,6 @@ namespace torchao {
6288
const std::optional<at::Tensor>& bias, \
6389
at::ScalarType output_dtype)
6490

65-
#define call_float8_linear_prepack_impl() \
66-
float8_linear_prepack_impl(weight, scales)
67-
68-
#define call_float8_linear_impl() \
69-
float8_linear_impl( \
70-
input, \
71-
input_scales, \
72-
weight, \
73-
weight_scales, \
74-
bias, \
75-
output_dtype)
76-
7791
/********** Scaled Embedding Bag Kernel Declare **********/
7892
#define declare_scaled_embedding_bag_impl \
7993
at::Tensor _scaled_embedding_bag_impl( \
@@ -86,17 +100,6 @@ namespace torchao {
86100
bool include_last_offset, \
87101
at::ScalarType output_dtype)
88102

89-
#define call_scaled_embedding_bag_impl() \
90-
_scaled_embedding_bag_impl( \
91-
qweight, \
92-
indices, \
93-
offsets, \
94-
w_scales, \
95-
o_scale, \
96-
mode, \
97-
include_last_offset, \
98-
output_dtype)
99-
100103
/********** Quantized SDPA Kernel Declare **********/
101104
#define declare_qscaled_dot_product_impl \
102105
at::Tensor _qscaled_dot_product_cpu( \
@@ -118,26 +121,6 @@ namespace torchao {
118121
double o_scale, \
119122
int64_t o_zp)
120123

121-
#define call_qscaled_dot_product_impl() \
122-
_qscaled_dot_product_cpu( \
123-
query, \
124-
key, \
125-
value, \
126-
attn_mask, \
127-
dropout_p, \
128-
is_causal, \
129-
scale, \
130-
q_scale, \
131-
q_zp, \
132-
k_scale, \
133-
k_zp, \
134-
v_scale, \
135-
v_zp, \
136-
a_scale, \
137-
a_zp, \
138-
o_scale, \
139-
o_zp)
140-
141124
/********** Declare All Kernels in All Namespaces **********/
142125
#define declare_all_kernels(namespace_name) \
143126
namespace namespace_name { \
@@ -153,93 +136,75 @@ declare_all_kernels(AVX10_2)
153136
declare_all_kernels(AVX512)
154137
declare_all_kernels(DEFAULT)
155138

156-
/********** DA8W4 Linear Kernel Dispatch **********/
157-
declare_da8w4_linear_prepack_impl {
158-
// BUILD_AVX10_2 is only defined when the compiler passed the AVX10.2 ISA probe in setup.py.
139+
/********** Dispatcher Selection and Dispatch Functions **********/
140+
// Select the appropriate dispatcher based on runtime ISA capabilities
141+
KernelDispatcher& get_kernel_dispatcher() {
142+
static KernelDispatcher dispatcher = []() {
143+
KernelDispatcher d;
144+
// Select ISA level based on runtime detection (kHas*) and compile-time checks (BUILD_*)
159145
#if defined(BUILD_AVX10_2)
160-
if (kHasAVX10_2) {
161-
return AVX10_2::call_da8w4_linear_prepack_impl();
162-
}
146+
if (kHasAVX10_2) {
147+
d = {AVX10_2::da8w4_linear_prepack_impl,
148+
AVX10_2::da8w4_linear_impl,
149+
AVX10_2::float8_linear_prepack_impl,
150+
AVX10_2::float8_linear_impl,
151+
AVX10_2::_scaled_embedding_bag_impl,
152+
AVX10_2::_qscaled_dot_product_cpu};
153+
return d;
154+
}
163155
#endif
164156
#if defined(BUILD_AVX512)
165-
if (kHasAVX512) {
166-
return AVX512::call_da8w4_linear_prepack_impl();
167-
}
157+
if (kHasAVX512) {
158+
d = {AVX512::da8w4_linear_prepack_impl,
159+
AVX512::da8w4_linear_impl,
160+
AVX512::float8_linear_prepack_impl,
161+
AVX512::float8_linear_impl,
162+
AVX512::_scaled_embedding_bag_impl,
163+
AVX512::_qscaled_dot_product_cpu};
164+
return d;
165+
}
168166
#endif
169-
return DEFAULT::call_da8w4_linear_prepack_impl();
167+
// Fall back to DEFAULT (always available)
168+
d = {DEFAULT::da8w4_linear_prepack_impl,
169+
DEFAULT::da8w4_linear_impl,
170+
DEFAULT::float8_linear_prepack_impl,
171+
DEFAULT::float8_linear_impl,
172+
DEFAULT::_scaled_embedding_bag_impl,
173+
DEFAULT::_qscaled_dot_product_cpu};
174+
return d;
175+
}();
176+
return dispatcher;
177+
}
178+
179+
declare_da8w4_linear_prepack_impl {
180+
return get_kernel_dispatcher().da8w4_linear_prepack(weight, scales, qzeros);
170181
}
171182

172183
declare_da8w4_linear_impl {
173-
#if defined(BUILD_AVX10_2)
174-
if (kHasAVX10_2) {
175-
return AVX10_2::call_da8w4_linear_impl();
176-
}
177-
#endif
178-
#if defined(BUILD_AVX512)
179-
if (kHasAVX512) {
180-
return AVX512::call_da8w4_linear_impl();
181-
}
182-
#endif
183-
return DEFAULT::call_da8w4_linear_impl();
184+
return get_kernel_dispatcher().da8w4_linear(
185+
input, input_scales, input_qzeros, weight, weight_scales, weight_qzeros,
186+
compensation, bias, output_dtype);
184187
}
185188

186-
/********** FLOAT8 Linear Kernel Dispatch **********/
187189
declare_float8_linear_prepack_impl {
188-
#if defined(BUILD_AVX10_2)
189-
if (kHasAVX10_2) {
190-
return AVX10_2::call_float8_linear_prepack_impl();
191-
}
192-
#endif
193-
#if defined(BUILD_AVX512)
194-
if (kHasAVX512) {
195-
return AVX512::call_float8_linear_prepack_impl();
196-
}
197-
#endif
198-
return DEFAULT::call_float8_linear_prepack_impl();
190+
return get_kernel_dispatcher().float8_linear_prepack(weight, scales);
199191
}
200192

201193
declare_float8_linear_impl {
202-
#if defined(BUILD_AVX10_2)
203-
if (kHasAVX10_2) {
204-
return AVX10_2::call_float8_linear_impl();
205-
}
206-
#endif
207-
#if defined(BUILD_AVX512)
208-
if (kHasAVX512) {
209-
return AVX512::call_float8_linear_impl();
210-
}
211-
#endif
212-
return DEFAULT::call_float8_linear_impl();
194+
return get_kernel_dispatcher().float8_linear(
195+
input, input_scales, weight, weight_scales, bias, output_dtype);
213196
}
214197

215-
/********** Scaled Embedding Bag Kernel Dispatch **********/
216198
declare_scaled_embedding_bag_impl {
217-
#if defined(BUILD_AVX10_2)
218-
if (kHasAVX10_2) {
219-
return AVX10_2::call_scaled_embedding_bag_impl();
220-
}
221-
#endif
222-
#if defined(BUILD_AVX512)
223-
if (kHasAVX512) {
224-
return AVX512::call_scaled_embedding_bag_impl();
225-
}
226-
#endif
227-
return DEFAULT::call_scaled_embedding_bag_impl();
199+
return get_kernel_dispatcher().scaled_embedding_bag(
200+
qweight, indices, offsets, w_scales, o_scale, mode, include_last_offset,
201+
output_dtype);
228202
}
229203

230-
/********** Quantized SDPA Kernel **********/
231204
declare_qscaled_dot_product_impl {
232-
#if defined(BUILD_AVX10_2)
233-
if (kHasAVX10_2) {
234-
return AVX10_2::call_qscaled_dot_product_impl();
235-
}
236-
#endif
237-
#if defined(BUILD_AVX512)
238-
if (kHasAVX512) {
239-
return AVX512::call_qscaled_dot_product_impl();
240-
}
241-
#endif
242-
return DEFAULT::call_qscaled_dot_product_impl();
205+
return get_kernel_dispatcher().qscaled_dot_product(
206+
query, key, value, attn_mask, dropout_p, is_causal, scale, q_scale, q_zp,
207+
k_scale, k_zp, v_scale, v_zp, a_scale, a_zp, o_scale, o_zp);
243208
}
244209

245210
TORCH_LIBRARY_IMPL(torchao, CPU, m) {

0 commit comments

Comments
 (0)