@@ -138,23 +138,27 @@ def test_quantized_scaled_dot_product_op(
         torch.manual_seed(1234)
         device = "cpu"
         if input_dtype == torch.uint8:
-            q_scale = float(1.7907238006591797)
-            k_scale = float(1.8039721250534058)
-            v_scale = float(1.839004635810852)
-            a_scale = float(0.003919653594493866)
-            o_scale = float(1.8191684484481812)
-            q_zp = int(127)
-            k_zp = int(125)
-            v_zp = int(127)
-            a_zp = int(120)
-            o_zp = int(128)
+            q_scale, k_scale, v_scale = (
+                float(1.7907238006591797),
+                float(1.8039721250534058),
+                float(1.839004635810852),
+            )
+            a_scale, o_scale = float(0.003919653594493866), float(1.8191684484481812)
+            q_zp, k_zp, v_zp, a_zp, o_zp = (
+                int(127),
+                int(125),
+                int(127),
+                int(120),
+                int(128),
+            )
             atol, rtol = 1.0, 5e-6
         else:
-            q_scale = float(5.96875)
-            k_scale = float(5.78125)
-            v_scale = float(0.98046875)
-            a_scale = float(4.84375)
-            o_scale = float(3.171875)
+            q_scale, k_scale, v_scale = (
+                float(5.96875),
+                float(5.78125),
+                float(0.98046875),
+            )
+            a_scale, o_scale = float(4.84375), float(3.171875)
             atol, rtol = 0.125, 5e-6
         q_shape = [batch_size, q_seq_len, n_head, head_dim]
         kv_shape = [batch_size, kv_seq_len, n_head, head_dim]
@@ -163,12 +167,8 @@ def test_quantized_scaled_dot_product_op(
         k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
         v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
         if input_dtype == torch.uint8:
-            q *= 100
-            k *= 100
-            v *= 100
-            q = q.to(input_dtype)
-            k = k.to(input_dtype)
-            v = v.to(input_dtype)
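+            # Scale up before casting so the uint8 inputs span a wide value range.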
+            q, k, v = (t * 100 for t in (q, k, v))
+            q, k, v = (t.to(input_dtype) for t in (q, k, v))
         attn_mask = (
             torch.randn(mask_shape, dtype=mask_dtype, device=device)
             if mask_dtype is not None
@@ -181,72 +181,106 @@ def test_quantized_scaled_dot_product_op(
             attn_mask.clone() if mask_dtype is not None else None,
         )
 
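+        # Gather the quantization parameters shared by the reference and the
+        # fused op, so each call site below stays on a single line.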
+        common_kwargs = dict(
+            dropout_p=0.0,
+            is_causal=False,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            a_scale=a_scale,
+            o_scale=o_scale,
+        )
+
         if input_dtype == torch.uint8:
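+            # uint8 quantization is asymmetric, so zero points are required too.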
+            common_kwargs.update(q_zp=q_zp, k_zp=k_zp, v_zp=v_zp, a_zp=a_zp, o_zp=o_zp)
             math_ref = self._scaled_dot_product_int8_op_ref(
-                q2,
-                k2,
-                v2,
-                attn_mask=attn_mask,
-                dropout_p=0.0,
-                is_causal=False,
-                q_scale=q_scale,
-                q_zp=q_zp,
-                k_scale=k_scale,
-                k_zp=k_zp,
-                v_scale=v_scale,
-                v_zp=v_zp,
-                a_scale=a_scale,
-                a_zp=a_zp,
-                o_scale=o_scale,
-                o_zp=o_zp,
+                q2, k2, v2, attn_mask=attn_mask, **common_kwargs
             )
             actual = torch.ops.torchao.qscaled_dot_product(
-                q,
-                k,
-                v,
-                attn_mask=attn_mask_2,
-                dropout_p=0.0,
-                is_causal=False,
-                q_scale=q_scale,
-                q_zp=q_zp,
-                k_scale=k_scale,
-                k_zp=k_zp,
-                v_scale=v_scale,
-                v_zp=v_zp,
-                a_scale=a_scale,
-                a_zp=a_zp,
-                o_scale=o_scale,
-                o_zp=o_zp,
+                q, k, v, attn_mask=attn_mask_2, **common_kwargs
             )
         else:
             math_ref = self._scaled_dot_product_fp8_op_ref(
-                q2,
-                k2,
-                v2,
-                attn_mask=attn_mask,
-                dropout_p=0.0,
-                is_causal=False,
-                q_scale=q_scale,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                a_scale=a_scale,
-                o_scale=o_scale,
+                q2, k2, v2, attn_mask=attn_mask, **common_kwargs
             )
             actual = torch.ops.torchao.qscaled_dot_product(
-                q,
-                k,
-                v,
-                attn_mask=attn_mask_2,
-                dropout_p=0.0,
-                is_causal=False,
-                q_scale=q_scale,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                a_scale=a_scale,
-                o_scale=o_scale,
+                q, k, v, attn_mask=attn_mask_2, **common_kwargs
             )
         self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol)
 
+    @pytest.mark.skipif(
+        not torch_version_at_least("2.7.0"),
+        reason="quantized sdpa requires torch 2.7 or later",
+    )
+    @pytest.mark.skipif(not IS_LINUX, reason="only support on linux")
+    @pytest.mark.skipif(
+        "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
+        reason="cpp kernels not built",
+    )
+    @parametrize("input_dtype", [torch.uint8, torch.float8_e4m3fn])
+    def test_quantized_scaled_dot_product_op_with_strided_inputs(
+        self,
+        input_dtype,
+    ):
+        torch.manual_seed(1234)
+        device = "cpu"
+        if input_dtype == torch.uint8:
+            q_scale, k_scale, v_scale = (
+                float(1.7907238006591797),
+                float(1.8039721250534058),
+                float(1.839004635810852),
+            )
+            a_scale, o_scale = float(0.003919653594493866), float(1.8191684484481812)
+            q_zp, k_zp, v_zp, a_zp, o_zp = (
+                int(127),
+                int(125),
+                int(127),
+                int(120),
+                int(128),
+            )
+            atol, rtol = 1.0, 5e-6
+        else:
+            q_scale, k_scale, v_scale = (
+                float(5.96875),
+                float(5.78125),
+                float(0.98046875),
+            )
+            a_scale, o_scale = float(4.84375), float(3.171875)
+            atol, rtol = 0.125, 5e-6
+        batch_size, seq_len, num_head, head_dim = 56, 100, 16, 32
+        hidden_size = num_head * head_dim
+        proj = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)
+        input_shape = (batch_size, seq_len, num_head * head_dim)
+        qkv_shape = (batch_size, seq_len, num_head, head_dim)
+        input_tensor = torch.randn(input_shape, dtype=torch.float, device=device)
+        hidden_state = proj(input_tensor)
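+        # torch.split along the last dim returns non-contiguous views, so q, k
+        # and v reach the op with non-trivial strides (the point of this test).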
+        q, k, v = torch.split(hidden_state, hidden_size, dim=-1)
+        q, k, v = (t.view(*qkv_shape).transpose(1, 2) for t in (q, k, v))
+        if input_dtype == torch.uint8:
+            q, k, v = (t * 100 for t in (q, k, v))
+        q, k, v = (t.to(input_dtype) for t in (q, k, v))
+        q2, k2, v2 = (t.clone() for t in (q, k, v))
+
+        common_kwargs = dict(
+            attn_mask=None,
+            dropout_p=0.0,
+            is_causal=False,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            a_scale=a_scale,
+            o_scale=o_scale,
+        )
+        if input_dtype == torch.uint8:
+            common_kwargs.update(q_zp=q_zp, k_zp=k_zp, v_zp=v_zp, a_zp=a_zp, o_zp=o_zp)
+            math_ref = self._scaled_dot_product_int8_op_ref(q2, k2, v2, **common_kwargs)
+            actual = torch.ops.torchao.qscaled_dot_product(q, k, v, **common_kwargs)
+        else:
+            math_ref = self._scaled_dot_product_fp8_op_ref(q2, k2, v2, **common_kwargs)
+            actual = torch.ops.torchao.qscaled_dot_product(q, k, v, **common_kwargs)
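+        # Even with strided inputs, the output should come back dense in
+        # (batch, seq_len, num_head, head_dim) order once transposed back.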
+        assert actual.transpose(1, 2).is_contiguous(), "Output is not contiguous!"
+        self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol)
+
 
 instantiate_parametrized_tests(TestOps)
 