     AttentionBackend,
     apply_low_precision_attention,
 )
-from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
+from torchao.prototype.attention.fp8_fa3.attention import (
+    fp8_fa3_rope_sdpa,
+    fp8_fa3_sdpa,
+)
+
+
+def _rope_cos_sin(S, D, device):
+    """Build NeoX-style RoPE cos/sin tables of shape [S, D] on `device`."""
+    freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+    angles = torch.outer(torch.arange(S, dtype=torch.float32), freqs)
+    cos_half = torch.cos(angles)
+    sin_half = torch.sin(angles)
+    cos = torch.cat([cos_half, cos_half], dim=-1).to(device)
+    sin = torch.cat([sin_half, sin_half], dim=-1).to(device)
+    return cos, sin
+
+
+def _apply_rope(x, cos, sin):
+    """NeoX rotate-half RoPE. x: [B, S, H, D], cos/sin: [S, D]."""
+    D_HALF = x.shape[-1] // 2
+    rotate = torch.cat([-x[..., D_HALF:], x[..., :D_HALF]], dim=-1)
+    return (
+        x * cos.unsqueeze(0).unsqueeze(2) + rotate * sin.unsqueeze(0).unsqueeze(2)
+    ).to(x.dtype)
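+
+
+# Note: because _rope_cos_sin duplicates each half of the table, _apply_rope is the
+# standard rotation on channel pairs (x1, x2) = (x[..., i], x[..., i + D_HALF]):
+#     x1 -> x1 * cos - x2 * sin,    x2 -> x2 * cos + x1 * sin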
 
 
 class SimpleAttentionModel(nn.Module):
@@ -52,6 +74,30 @@ def forward(self, x):
         return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))
 
 
+class SimpleRoPEAttentionModel(nn.Module):
+    """Applies RoPE to Q and K immediately before SDPA (Pattern A: RoPE → transpose → SDPA)."""
+
+    def __init__(self, embed_dim, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+
+    def forward(self, x, cos, sin):
+        B, S, _ = x.shape
+        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim)
+        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim)
+        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim)
+        q = _apply_rope(q, cos, sin).transpose(1, 2)
+        k = _apply_rope(k, cos, sin).transpose(1, 2)
+        v = v.transpose(1, 2)
+        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))
+
+
 @common_utils.instantiate_parametrized_tests
 class TestFP8FA3Attention(TestCase):
     @unittest.skipUnless(
@@ -83,6 +129,41 @@ def test_sdpa_accuracy(self, shape, dtype):
83129 f"SQNR { sqnr .item ():.2f} dB below 25 dB for shape={ shape } , dtype={ dtype } " ,
84130 )
85131
+    @unittest.skipUnless(
+        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
+        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
+    )
+    @common_utils.parametrize("shape", [(2, 1024, 8, 64), (1, 1024, 16, 128)])
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_rope_sdpa_accuracy(self, shape, dtype):
+        B, S, H, D = shape
+        q = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
+        k = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
+        v = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
+        cos, sin = _rope_cos_sin(S, D, "cuda")
+
+        with torch.no_grad():
+            out_ref = F.scaled_dot_product_attention(
+                _apply_rope(q, cos, sin).transpose(1, 2),
+                _apply_rope(k, cos, sin).transpose(1, 2),
+                v.transpose(1, 2),
+                is_causal=False,
+            )
+
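+        # Route the FP8 path through the FA3 kernel; restore the previous
+        # flash-attention impl afterwards (the try/finally guards test isolation).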
+        activate_flash_attention_impl("FA3")
+        try:
+            with torch.no_grad():
+                out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
+        finally:
+            restore_flash_attention_impl()
+
+        sqnr = compute_error(out_ref, out_fp8)
+        self.assertGreater(
+            sqnr.item(),
+            25.0,
+            f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
+        )
+
     @unittest.skipUnless(
         torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
         "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
@@ -122,6 +203,48 @@ def test_monkey_patch_model(self, dtype):
122203 f"SQNR { sqnr .item ():.2f} dB below 20 dB for dtype={ dtype } " ,
123204 )
124205
+    @unittest.skipUnless(
+        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
+        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_rope_fusion_model(self, dtype):
+        embed_dim, num_heads = 512, 8
+        model = (
+            SimpleRoPEAttentionModel(embed_dim, num_heads)
+            .to(device="cuda", dtype=dtype)
+            .eval()
+        )
+        S = 128
+        x = torch.randn(2, S, embed_dim, device="cuda", dtype=dtype)
+        cos, sin = _rope_cos_sin(S, embed_dim // num_heads, "cuda")
+
+        with torch.no_grad():
+            out_ref = model(x, cos, sin)
+
+        fp8_model = (
+            SimpleRoPEAttentionModel(embed_dim, num_heads)
+            .to(device="cuda", dtype=dtype)
+            .eval()
+        )
+        fp8_model.load_state_dict(model.state_dict())
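+        # fuse_rope_using_torch_compile presumably directs the FP8_FA3 backend to
+        # pattern-match the RoPE -> transpose -> SDPA sequence at compile time;
+        # hence the torch.compile(..., backend=fp8_model.compile_backend) below.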
+        fp8_model = apply_low_precision_attention(
+            fp8_model,
+            backend=AttentionBackend.FP8_FA3,
+            fuse_rope_using_torch_compile=True,
+        )
+        fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend)
+
+        with torch.no_grad():
+            out_fp8 = fp8_model(x, cos, sin)
+
+        sqnr = compute_error(out_ref, out_fp8)
+        self.assertGreater(
+            sqnr.item(),
+            20.0,
+            f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
+        )
+
 
 if __name__ == "__main__":
     run_tests()