
Commit 716c91c

fix: prevent inductor from fusing away bf16→fp32 cast in RoPE (pytorch#2575)
With compile.enable=true, Qwen3 produces different outputs from eager (max diff ~1.56 in bfloat16). Inductor traces the whole transformer block as one graph and legally eliminates the .to(dtype=xq.dtype) downcast between q_norm/k_norm and RoPE, keeping the multiply-add in fp32. The algebra is valid, but the dtype boundary no longer matches eager.

The fix is borrowed from apply_rotary_emb_complex: upcast xq/xk to float32 before the multiply-add instead of downcasting cos/sin to match. The fp32 compute is now unconditional in the graph, so Inductor has nothing to fuse away. Results are cast back with type_as at the end, as before.

Fixes Qwen3 and GPT-OSS (the only callers of apply_rotary_emb_cos_sin). RoPE now always computes in fp32 in eager too: slightly more accurate, matches HF Qwen3 behavior, and has no checkpoint impact.
1 parent 766f181 commit 716c91c

2 files changed

Lines changed: 74 additions & 4 deletions
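A quick way to sanity-check the fix is to compare eager and compiled outputs of the rotary embedding directly. Below is a minimal sketch, assuming apply_rotary_emb_cos_sin takes (xq, xk, rope_cache) as in the new test and that the cache packs the cos and sin halves along the last dimension; note the original divergence required q_norm/k_norm and RoPE in one Inductor graph, so this isolated check is a smoke test rather than a full reproduction of the bug.

import torch

from torchtitan.models.common.rope import apply_rotary_emb_cos_sin

torch.manual_seed(0)
xq = torch.randn(2, 16, 4, 64, dtype=torch.bfloat16)
xk = torch.randn(2, 16, 4, 64, dtype=torch.bfloat16)
rope_cache = torch.randn(16, 128, dtype=torch.float32)  # cos | sin along last dim

# Eager reference.
eager_q, eager_k = apply_rotary_emb_cos_sin(xq, xk, rope_cache)

# Compiled version. With the fp32 compute made unconditional, Inductor has no
# dtype boundary left to optimize away, so outputs should match eager exactly.
compiled_fn = torch.compile(apply_rotary_emb_cos_sin)
compiled_q, compiled_k = compiled_fn(xq, xk, rope_cache)

print((eager_q - compiled_q).abs().max().item())  # expect 0.0
print((eager_k - compiled_k).abs().max().item())  # expect 0.0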


tests/unit_tests/test_rope.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from torchtitan.models.common.rope import apply_rotary_emb_cos_sin


class TestApplyRotaryEmbCosSin(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)
        self.bsz = 2
        self.seqlen = 16
        self.n_heads = 4
        self.head_dim = 64
        self.xq = torch.randn(
            self.bsz, self.seqlen, self.n_heads, self.head_dim, dtype=torch.bfloat16
        )
        self.xk = torch.randn(
            self.bsz, self.seqlen, self.n_heads, self.head_dim, dtype=torch.bfloat16
        )
        self.rope_cache = torch.randn(
            self.seqlen, self.head_dim * 2, dtype=torch.float32
        )

    def test_output_dtype_matches_input(self):
        xq_out, xk_out = apply_rotary_emb_cos_sin(self.xq, self.xk, self.rope_cache)
        self.assertEqual(xq_out.dtype, self.xq.dtype)
        self.assertEqual(xk_out.dtype, self.xk.dtype)

    def test_output_shape_matches_input(self):
        xq_out, xk_out = apply_rotary_emb_cos_sin(self.xq, self.xk, self.rope_cache)
        self.assertEqual(xq_out.shape, self.xq.shape)
        self.assertEqual(xk_out.shape, self.xk.shape)

    def test_computes_in_fp32(self):
        """Output must match a reference computed entirely in float32.

        Ensures inductor cannot fuse away the fp32 upcast when compiling
        adjacent ops (e.g. q_norm/k_norm) with the RoPE computation.
        """
        xq_out, xk_out = apply_rotary_emb_cos_sin(self.xq, self.xk, self.rope_cache)

        cos = self.rope_cache[..., : self.head_dim].unsqueeze(0).unsqueeze(2)
        sin = self.rope_cache[..., self.head_dim :].unsqueeze(0).unsqueeze(2)

        def rotate_half(x):
            half = x.shape[-1] // 2
            return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

        xq_ref = (
            (self.xq.float() * cos) + (rotate_half(self.xq.float()) * sin)
        ).bfloat16()
        xk_ref = (
            (self.xk.float() * cos) + (rotate_half(self.xk.float()) * sin)
        ).bfloat16()

        self.assertEqual((xq_out - xq_ref).abs().max().item(), 0.0)
        self.assertEqual((xk_out - xk_ref).abs().max().item(), 0.0)


if __name__ == "__main__":
    unittest.main()
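Note that the fp32 test asserts exact equality (max abs diff of 0.0) rather than a tolerance: once the upcast is explicit, the function and the reference should follow the same fp32 multiply-add and bfloat16 rounding path, so any nonzero difference would indicate the dtype boundary has moved again.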

torchtitan/models/common/rope.py

Lines changed: 6 additions & 4 deletions
@@ -347,8 +347,10 @@ def apply_rotary_emb_cos_sin(
     """
     head_dim = xq.shape[-1]
     rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions)
-    cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device)
-    sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device)
-    xq_out = (xq * cos) + (_rotate_half(xq) * sin)
-    xk_out = (xk * cos) + (_rotate_half(xk) * sin)
+    cos = rope_cache[..., :head_dim].to(device=xq.device)
+    sin = rope_cache[..., head_dim:].to(device=xq.device)
+    xq_f = xq.float()
+    xk_f = xk.float()
+    xq_out = (xq_f * cos) + (_rotate_half(xq_f) * sin)
+    xk_out = (xk_f * cos) + (_rotate_half(xk_f) * sin)
     return xq_out.type_as(xq), xk_out.type_as(xk)
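For intuition on why the dtype boundary matters, here is a small standalone sketch (not part of the commit) contrasting the old eager path, which rounds cos/sin and the multiply-add to bfloat16, with the fp32-compute-then-downcast path that fused compilation effectively produced and that the fix now uses everywhere.

import torch

torch.manual_seed(0)
x = torch.randn(4, 64, dtype=torch.bfloat16)
cos = torch.randn(64, dtype=torch.float32)
sin = torch.randn(64, dtype=torch.float32)

def rotate_half(t):
    half = t.shape[-1] // 2
    return torch.cat([-t[..., half:], t[..., :half]], dim=-1)

# Old eager path: downcast cos/sin to bf16, do the multiply-add in bf16.
old = (x * cos.bfloat16()) + (rotate_half(x) * sin.bfloat16())

# What fused compilation effectively computed (and what the fix now does
# unconditionally): keep the multiply-add in fp32, downcast only at the end.
new = ((x.float() * cos) + (rotate_half(x.float()) * sin)).bfloat16()

# The two dtype boundaries round differently, so the bf16 outputs diverge.
print((old - new).abs().max().item())  # typically nonzero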
