
Commit bbe615c

[CPU][QSDPA] Fix issue with strided input (#3794)
* [cpu][qsdpa] fix issue with strided input
1 parent e094ce3 · commit bbe615c
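The fix concerns q/k/v tensors that reach the quantized SDPA kernel with non-contiguous strides. For context, here is a minimal sketch (illustrative only, not part of the commit; shapes are arbitrary) of how such strided inputs typically arise, mirroring the new test below: a fused QKV projection whose output is split and transposed into per-head layout.

import torch

# Hypothetical small shapes, for illustration only.
batch, seq, heads, dim = 2, 8, 4, 16
hidden = heads * dim

# A fused QKV projection, split along the last dim and transposed into
# per-head layout, yields non-contiguous ("strided") q/k/v -- the input
# pattern this commit fixes.
proj = torch.nn.Linear(hidden, 3 * hidden, bias=False)
x = torch.randn(batch, seq, hidden)
q, k, v = torch.split(proj(x), hidden, dim=-1)
q, k, v = (t.view(batch, seq, heads, dim).transpose(1, 2) for t in (q, k, v))
print(q.is_contiguous())  # False: the kernel receives a strided tensor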

3 files changed

Lines changed: 116 additions & 78 deletions


test/test_ops.py

Lines changed: 109 additions & 75 deletions
@@ -138,23 +138,27 @@ def test_quantized_scaled_dot_product_op(
         torch.manual_seed(1234)
         device = "cpu"
         if input_dtype == torch.uint8:
-            q_scale = float(1.7907238006591797)
-            k_scale = float(1.8039721250534058)
-            v_scale = float(1.839004635810852)
-            a_scale = float(0.003919653594493866)
-            o_scale = float(1.8191684484481812)
-            q_zp = int(127)
-            k_zp = int(125)
-            v_zp = int(127)
-            a_zp = int(120)
-            o_zp = int(128)
+            q_scale, k_scale, v_scale = (
+                float(1.7907238006591797),
+                float(1.8039721250534058),
+                float(1.839004635810852),
+            )
+            a_scale, o_scale = float(0.003919653594493866), float(1.8191684484481812)
+            q_zp, k_zp, v_zp, a_zp, o_zp = (
+                int(127),
+                int(125),
+                int(127),
+                int(120),
+                int(128),
+            )
             atol, rtol = 1.0, 5e-6
         else:
-            q_scale = float(5.96875)
-            k_scale = float(5.78125)
-            v_scale = float(0.98046875)
-            a_scale = float(4.84375)
-            o_scale = float(3.171875)
+            q_scale, k_scale, v_scale = (
+                float(5.96875),
+                float(5.78125),
+                float(0.98046875),
+            )
+            a_scale, o_scale = float(4.84375), float(3.171875)
             atol, rtol = 0.125, 5e-6
         q_shape = [batch_size, q_seq_len, n_head, head_dim]
         kv_shape = [batch_size, kv_seq_len, n_head, head_dim]

@@ -163,12 +167,8 @@ def test_quantized_scaled_dot_product_op(
         k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
         v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
         if input_dtype == torch.uint8:
-            q *= 100
-            k *= 100
-            v *= 100
-        q = q.to(input_dtype)
-        k = k.to(input_dtype)
-        v = v.to(input_dtype)
+            q, k, v = (t * 100 for t in (q, k, v))
+        q, k, v = (t.to(input_dtype) for t in (q, k, v))
         attn_mask = (
             torch.randn(mask_shape, dtype=mask_dtype, device=device)
             if mask_dtype is not None

@@ -181,72 +181,106 @@ def test_quantized_scaled_dot_product_op(
             attn_mask.clone() if mask_dtype is not None else None,
         )

+        common_kwargs = dict(
+            dropout_p=0.0,
+            is_causal=False,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            a_scale=a_scale,
+            o_scale=o_scale,
+        )
+
         if input_dtype == torch.uint8:
+            common_kwargs.update(q_zp=q_zp, k_zp=k_zp, v_zp=v_zp, a_zp=a_zp, o_zp=o_zp)
             math_ref = self._scaled_dot_product_int8_op_ref(
-                q2,
-                k2,
-                v2,
-                attn_mask=attn_mask,
-                dropout_p=0.0,
-                is_causal=False,
-                q_scale=q_scale,
-                q_zp=q_zp,
-                k_scale=k_scale,
-                k_zp=k_zp,
-                v_scale=v_scale,
-                v_zp=v_zp,
-                a_scale=a_scale,
-                a_zp=a_zp,
-                o_scale=o_scale,
-                o_zp=o_zp,
+                q2, k2, v2, attn_mask=attn_mask, **common_kwargs
             )
             actual = torch.ops.torchao.qscaled_dot_product(
-                q,
-                k,
-                v,
-                attn_mask=attn_mask_2,
-                dropout_p=0.0,
-                is_causal=False,
-                q_scale=q_scale,
-                q_zp=q_zp,
-                k_scale=k_scale,
-                k_zp=k_zp,
-                v_scale=v_scale,
-                v_zp=v_zp,
-                a_scale=a_scale,
-                a_zp=a_zp,
-                o_scale=o_scale,
-                o_zp=o_zp,
+                q, k, v, attn_mask=attn_mask_2, **common_kwargs
             )
         else:
             math_ref = self._scaled_dot_product_fp8_op_ref(
-                q2,
-                k2,
-                v2,
-                attn_mask=attn_mask,
-                dropout_p=0.0,
-                is_causal=False,
-                q_scale=q_scale,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                a_scale=a_scale,
-                o_scale=o_scale,
+                q2, k2, v2, attn_mask=attn_mask, **common_kwargs
             )
             actual = torch.ops.torchao.qscaled_dot_product(
-                q,
-                k,
-                v,
-                attn_mask=attn_mask_2,
-                dropout_p=0.0,
-                is_causal=False,
-                q_scale=q_scale,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                a_scale=a_scale,
-                o_scale=o_scale,
+                q, k, v, attn_mask=attn_mask_2, **common_kwargs
             )
         self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol)

+    @pytest.mark.skipif(
+        not torch_version_at_least("2.7.0"),
+        reason="quantized sdpa requires torch 2.7 or later",
+    )
+    @pytest.mark.skipif(not IS_LINUX, reason="only support on linux")
+    @pytest.mark.skipif(
+        "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
+        reason="cpp kernels not built",
+    )
+    @parametrize("input_dtype", [torch.uint8, torch.float8_e4m3fn])
+    def test_quantized_scaled_dot_product_op_with_strided_inputs(
+        self,
+        input_dtype,
+    ):
+        torch.manual_seed(1234)
+        device = "cpu"
+        if input_dtype == torch.uint8:
+            q_scale, k_scale, v_scale = (
+                float(1.7907238006591797),
+                float(1.8039721250534058),
+                float(1.839004635810852),
+            )
+            a_scale, o_scale = float(0.003919653594493866), float(1.8191684484481812)
+            q_zp, k_zp, v_zp, a_zp, o_zp = (
+                int(127),
+                int(125),
+                int(127),
+                int(120),
+                int(128),
+            )
+            atol, rtol = 1.0, 5e-6
+        else:
+            q_scale, k_scale, v_scale = (
+                float(5.96875),
+                float(5.78125),
+                float(0.98046875),
+            )
+            a_scale, o_scale = float(4.84375), float(3.171875)
+            atol, rtol = 0.125, 5e-6
+        batch_size, seq_len, num_head, head_dim = 56, 100, 16, 32
+        hidden_size = num_head * head_dim
+        proj = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)
+        input_shape = (batch_size, seq_len, num_head * head_dim)
+        qkv_shape = (batch_size, seq_len, num_head, head_dim)
+        input_tensor = torch.randn(input_shape, dtype=torch.float, device=device)
+        hidden_state = proj(input_tensor)
+        q, k, v = torch.split(hidden_state, hidden_size, dim=-1)
+        q, k, v = (t.view(*qkv_shape).transpose(1, 2) for t in (q, k, v))
+        if input_dtype == torch.uint8:
+            q, k, v = (t * 100 for t in (q, k, v))
+        q, k, v = (t.to(input_dtype) for t in (q, k, v))
+        q2, k2, v2 = (t.clone() for t in (q, k, v))
+
+        common_kwargs = dict(
+            attn_mask=None,
+            dropout_p=0.0,
+            is_causal=False,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            a_scale=a_scale,
+            o_scale=o_scale,
+        )
+        if input_dtype == torch.uint8:
+            common_kwargs.update(q_zp=q_zp, k_zp=k_zp, v_zp=v_zp, a_zp=a_zp, o_zp=o_zp)
+            math_ref = self._scaled_dot_product_int8_op_ref(q2, k2, v2, **common_kwargs)
+            actual = torch.ops.torchao.qscaled_dot_product(q, k, v, **common_kwargs)
+        else:
+            math_ref = self._scaled_dot_product_fp8_op_ref(q2, k2, v2, **common_kwargs)
+            actual = torch.ops.torchao.qscaled_dot_product(q, k, v, **common_kwargs)
+        assert actual.transpose(1, 2).is_contiguous(), "Output is not contiguous!"
+        self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol)
+

 instantiate_parametrized_tests(TestOps)
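The new test's final assertion pins down the kernel's memory layout: `actual` has logical shape [batch, num_head, seq_len, head_dim], and `actual.transpose(1, 2).is_contiguous()` holds exactly when the underlying buffer is dense in [batch, seq_len, num_head, head_dim] order. A minimal illustration of what the assertion accepts and rejects (not from the diff; shapes arbitrary):

import torch

# A dense [B, S, H, D] buffer viewed as [B, H, S, D] -- the layout the
# fixed kernel produces -- passes the check:
out = torch.empty(2, 8, 4, 16).transpose(1, 2)  # logical shape [B, H, S, D]
assert out.transpose(1, 2).is_contiguous()

# A buffer that is simply contiguous in [B, H, S, D] order does not:
plain = torch.empty(2, 4, 8, 16)
assert not plain.transpose(1, 2).is_contiguous()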

torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp

Lines changed: 6 additions & 2 deletions
@@ -2525,11 +2525,15 @@ at::Tensor _qscaled_dot_product_cpu(
     TORCH_CHECK(q_zp == 0 && k_zp == 0 && v_zp == 0 && a_zp == 0 && o_zp == 0,
       "_qscaled_dot_product_cpu: Don't accept zero point for Float8_e4m3");
   }
+  int64_t batch_size = query.size(0);
+  int64_t num_head = query.size(1);
+  int64_t q_seq_len = query.size(2);
+  int64_t head_size = query.size(3);

   if (dtype == at::ScalarType::Byte) {
 #ifdef CPU_CAPABILITY_AVX512
     if (at::native::cpublas::could_pack(dtype)) {
-      at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2);
+      at::Tensor output = at::empty({batch_size, q_seq_len, num_head, head_size}, query.options());
       int8_sdpa_fused_kernel(output, query, key, value,
           dropout_p, is_causal, attn_mask, scale,
           q_scale, q_zp,

@@ -2554,7 +2558,7 @@ at::Tensor _qscaled_dot_product_cpu(
 #if defined(CPUBLAS_BRGEMM_F8F8F32) && defined(CPU_CAPABILITY_AVX512)
     // CPUBLAS_BRGEMM_F8F8F32 is defined if FP8 BRGEMM is supported in PyTorch CPUBlas.
     if (at::native::cpublas::could_pack(dtype)) {
-      at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2);
+      at::Tensor output = at::empty({batch_size, q_seq_len, num_head, head_size}, query.options());
       fp8_sdpa_fused_kernel(output, query, key, value,
           dropout_p, is_causal, attn_mask, scale,
           q_scale, k_scale,
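This C++ change is the core of the fix. `at::empty_like` preserves the input's strides only when the input is non-overlapping and dense; for a strided `query` coming from a QKV split it falls back to a contiguous [batch, num_head, seq_len, head_size] buffer, so the old `.transpose(1, 2)` result was no longer the dense [batch, seq_len, num_head, head_size] layout the fused kernels write into. A Python sketch of the same behavior (assuming `torch.empty_like` matches `at::empty_like` here, which is the documented `preserve_format` fallback; shapes are illustrative):

import torch

# A strided query like the one in the new test: a slice of a fused QKV
# projection, viewed per-head and transposed (non-dense in memory).
qkv = torch.randn(2, 8, 3 * 64)
q = qkv[..., :64].view(2, 8, 4, 16).transpose(1, 2)  # [B, H, S, D]

# Old allocation: empty_like cannot preserve q's non-dense layout and
# falls back to contiguous [B, H, S, D]; after transpose the buffer is
# not the dense [B, S, H, D] the fused kernel assumes.
old_out = torch.empty_like(q).transpose(1, 2)
print(old_out.is_contiguous())  # False

# Fixed allocation: an explicit dense [B, S, H, D] buffer.
new_out = torch.empty(2, 8, 4, 16)
print(new_out.is_contiguous())  # True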

torchao/ops.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def _(
     o_scale: float = 1.0,
     o_zp: int = 0,
 ) -> Tensor:
-    return query
+    return query.transpose(1, 2).contiguous().transpose(1, 2)


 def rowwise_scaled_linear_sparse_cutlass_f8f8(
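The changed `def _(...)` in torchao/ops.py appears to be the op's fake (meta) implementation used for shape and stride propagation (e.g., under torch.compile). Returning `query` unchanged advertised the input's strides; the new expression instead reports the strides of a dense [batch, seq_len, num_head, head_dim] buffer viewed as [batch, num_head, seq_len, head_dim], matching what the CPU kernel now allocates. A small check of what the new expression produces (illustrative, not from the diff):

import torch

# A strided query, as the fake tensors may be at trace time.
query = torch.randn(2, 8, 3 * 64)[..., :64].view(2, 8, 4, 16).transpose(1, 2)

# New meta body: same shape as query, but the strides of a dense
# [B, S, H, D] buffer viewed as [B, H, S, D] -- like the real kernel output.
fake = query.transpose(1, 2).contiguous().transpose(1, 2)
print(fake.shape)                            # torch.Size([2, 4, 8, 16])
print(fake.transpose(1, 2).is_contiguous())  # True, matches the kernel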
