From 83415b926b81bad566cca02622b59dd65b62e512 Mon Sep 17 00:00:00 2001 From: Ivan Petrov Date: Wed, 17 Jun 2026 02:05:30 +0300 Subject: [PATCH 1/7] Added softmax_scale in attentions. Separated attention_core call from attention with Q,K,V layers. Added tests for attentions. --- src/prime_rl/trainer/models/layers/attn.py | 294 +++++++++++++----- .../trainer/models/layers/ring_attn.py | 14 +- .../trainer/models/layers/ulysses_attn.py | 3 +- 3 files changed, 238 insertions(+), 73 deletions(-) diff --git a/src/prime_rl/trainer/models/layers/attn.py b/src/prime_rl/trainer/models/layers/attn.py index 5fa6e4bc6d..c62e27ab09 100644 --- a/src/prime_rl/trainer/models/layers/attn.py +++ b/src/prime_rl/trainer/models/layers/attn.py @@ -12,19 +12,25 @@ # flash-attention-2 try: from flash_attn import flash_attn_varlen_func + from flash_attn import flash_attn_func as fa2_func except ImportError: flash_attn_varlen_func = None # type: ignore + fa2_func = None # type: ignore # flash-attention-3 try: from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func + from flash_attn_interface import flash_attn_func as fa3_func except ImportError: flash_attn_3_varlen_func = None # type: ignore + fa3_func = None # type: ignore try: from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func + from flash_attn.cute import flash_attn_func as fa4_func except ImportError: flash_attn_4_varlen_func = None # type: ignore + fa4_func = None @dataclass @@ -39,6 +45,11 @@ class AttentionConfig: rms_norm_eps: float qk_norm_type: Literal["per_head", "per_layer"] = "per_head" output_bias: bool = False + scaling: float | None = None + + def __post_init__(self): + if self.scaling is None: + self.scaling = self.head_dim**-0.5 # TODO: Does torch compile support config._attn_implementation forking? @@ -46,55 +57,52 @@ class AttentionConfig: # Otherwise, do ABC or something to make the signatures match -class FlashAttention(nn.Module): - """Flash Attention""" +class FlashAttentionCore(nn.Module): + """Plain Flash Attention.""" - _funcs = { + _funcs_varlen = { 2: flash_attn_varlen_func, 3: flash_attn_3_varlen_func, 4: flash_attn_4_varlen_func, } + _funcs = { + 2: fa2_func, + 3: fa3_func, + 4: fa4_func, + } + def __init__(self, config: AttentionConfig, flash_attn_version: int = 2): super().__init__() - self.head_dim = config.head_dim - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 + self.scaling = config.scaling self.is_causal = config.is_causal - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.output_bias) - self.use_qk_norm = config.use_qk_norm - self.qk_norm_type = config.qk_norm_type - if self.use_qk_norm: - if self.qk_norm_type == "per_layer": - self.q_norm = RMSNorm( - RMSNormConfig(hidden_size=config.num_attention_heads * self.head_dim, eps=config.rms_norm_eps) - ) - self.k_norm = RMSNorm( - RMSNormConfig(hidden_size=config.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps) - ) - else: - self.q_norm = RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps)) - self.k_norm = RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps)) - self._flash_attn_version = flash_attn_version - self.func = self._funcs[flash_attn_version] + self.att_core_func = self._funcs[flash_attn_version] + self.func = self._funcs_varlen[flash_attn_version] self._flash_attn_call = self.func if self._flash_attn_version == 4: self._flash_attn_call = torch._dynamo.disable(self.func) - def _compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens, max_seqlen): - """Run the flash attention kernel. q/k/v are [total_tokens, heads, dim].""" - kwargs: dict = {"causal": True} + def _fa_installed(self): + """Checks that flash attention is installed.""" + return self._funcs_varlen[self._flash_attn_version] is not None + + def _compute_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens, + max_seqlen, + softmax_scale, + ): + """Run the varlen flash attention kernel. q/k/v are [total_tokens, heads, dim]. + + When running Ring attention or Ulysses this method will be patched. + """ + + kwargs: dict = {"causal": True, "softmax_scale": softmax_scale} sliding_window = getattr(self, "sliding_window", None) if sliding_window is not None: kwargs["window_size"] = (sliding_window - 1, 0) @@ -105,7 +113,9 @@ def _compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kwargs["cu_seqlens_k"] = cu_seqlens out = self._flash_attn_call(q, k, v, **kwargs) else: - out = self._flash_attn_call(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, **kwargs) + out = self._flash_attn_call( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, **kwargs + ) if isinstance(out, tuple): out = out[0] return out @@ -115,11 +125,138 @@ def _attention_core( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: int | None = None, max_seqlen: int | None = None, + softmax_scale: int | None = None, ) -> torch.Tensor: - out = self._compute_attention(query_states[0], key_states[0], value_states[0], cu_seqlens, max_seqlen) - return out.contiguous().view(1, out.shape[0], -1) + # q,k,v - [bs, sl, nh, hdim] + + if softmax_scale is None: + softmax_scale = self.scaling + + bs = query_states.shape[0] + if cu_seqlens is None: + # Non-varlen: q/k/v are [bs, sl, nh, hdim]; FA returns same shape. + out = self.att_core_func( + query_states, + key_states, + value_states, + causal=True, + softmax_scale=softmax_scale, + ) + if isinstance(out, tuple): + # 'fa4' returns tuple + out = out[0] + return out.view(out.shape[0], out.shape[1], -1) + elif bs == 1: + # Varlen: FA expects [total_tokens, nh, hdim] and returns the same. + out = self._compute_attention( + query_states[0], + key_states[0], + value_states[0], + cu_seqlens, + max_seqlen, + softmax_scale=softmax_scale, + ) + if isinstance(out, tuple): + # 'fa4' returns tuple + out = out[0] + # Reshape back to [1, total_tokens, hidden] so o_proj sees a 3D tensor. + return out.contiguous().view(1, out.shape[0], -1) + + raise NotImplementedError("varlen attention with bs > 1 is not supported") + + +class SDPAAttentionCore(nn.Module): + """Plain SDPA Attention.""" + + def __init__(self, config: AttentionConfig): + super().__init__() + + self.head_dim = config.head_dim + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = config.scaling + self.is_causal = config.is_causal + + def _attention_core( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + softmax_scale: int | None = None, + ) -> torch.Tensor: + + if softmax_scale is None: + softmax_scale = self.scaling + + key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) + value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) + out = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + scale=softmax_scale, + is_causal=self.is_causal, + ) + out = out.transpose(1, 2).contiguous() + return out.view(out.shape[0], out.shape[1], -1) + + +class FlashAttention(FlashAttentionCore): + """Flash Attention with Q,K,V""" + + def __init__(self, config: AttentionConfig, flash_attn_version: int = 2): + super().__init__(config, flash_attn_version) + self.head_dim = config.head_dim + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.output_bias, + ) + self.use_qk_norm = config.use_qk_norm + self.qk_norm_type = config.qk_norm_type + if self.use_qk_norm: + if self.qk_norm_type == "per_layer": + self.q_norm = RMSNorm( + RMSNormConfig( + hidden_size=config.num_attention_heads * self.head_dim, + eps=config.rms_norm_eps, + ) + ) + self.k_norm = RMSNorm( + RMSNormConfig( + hidden_size=config.num_key_value_heads * self.head_dim, + eps=config.rms_norm_eps, + ) + ) + else: + self.q_norm = RMSNorm( + RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps) + ) + self.k_norm = RMSNorm( + RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps) + ) def attn_projections( self, @@ -151,7 +288,9 @@ def attn_projections( if position_embeddings is not None: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) # TODO: Can we optimize the rotary application instead of double transpose? query_states = query_states.transpose(1, 2) @@ -170,7 +309,9 @@ def forward( cu_seqlens: torch.LongTensor | None = None, max_seqlen: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - query_states, key_states, value_states = self.attn_projections(hidden_states, position_embeddings) + query_states, key_states, value_states = self.attn_projections( + hidden_states, position_embeddings + ) attn_output = self._attention_core( query_states, @@ -183,39 +324,55 @@ def forward( return attn_output, None -class SDPAAttention(nn.Module): +class SDPAAttention(SDPAAttentionCore): """SDPA Attention""" def __init__(self, config: AttentionConfig): - super().__init__() - self.head_dim = config.head_dim - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.is_causal = config.is_causal + super().__init__(config) self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, ) self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, ) self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.output_bias, ) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.output_bias) self.use_qk_norm = config.use_qk_norm self.qk_norm_type = config.qk_norm_type if self.use_qk_norm: if self.qk_norm_type == "per_layer": self.q_norm = RMSNorm( - RMSNormConfig(hidden_size=config.num_attention_heads * self.head_dim, eps=config.rms_norm_eps) + RMSNormConfig( + hidden_size=config.num_attention_heads * self.head_dim, + eps=config.rms_norm_eps, + ) ) self.k_norm = RMSNorm( - RMSNormConfig(hidden_size=config.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps) + RMSNormConfig( + hidden_size=config.num_key_value_heads * self.head_dim, + eps=config.rms_norm_eps, + ) ) else: - self.q_norm = RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps)) - self.k_norm = RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps)) + self.q_norm = RMSNorm( + RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps) + ) + self.k_norm = RMSNorm( + RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps) + ) def attn_projections( self, @@ -247,22 +404,12 @@ def attn_projections( if position_embeddings is not None: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) return query_states, key_states, value_states - def _attention_core( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - ) -> torch.Tensor: - key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) - value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) - out = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True) - out = out.transpose(1, 2).contiguous() - return out.view(out.shape[0], out.shape[1], -1) - def output_proj(self, attn_output: torch.Tensor) -> torch.Tensor: return self.o_proj(attn_output) @@ -273,7 +420,9 @@ def forward( cu_seqlens: torch.LongTensor | None = None, max_seqlen: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - query_states, key_states, value_states = self.attn_projections(hidden_states, position_embeddings) + query_states, key_states, value_states = self.attn_projections( + hidden_states, position_embeddings + ) attn_output = self._attention_core(query_states, key_states, value_states) attn_output = self.output_proj(attn_output) @@ -305,7 +454,9 @@ def substitute_ring_attn( else: ring_func = llama3_flash_attn_varlen_func - def _ring_compute_attention(self, q, k, v, cu_seqlens, max_seqlen): + def _ring_compute_attention( + self, q, k, v, cu_seqlens, max_seqlen, softmax_scale=None + ): from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS window_size = (-1, -1) @@ -326,6 +477,7 @@ def _ring_compute_attention(self, q, k, v, cu_seqlens, max_seqlen): window_size=window_size, group=process_group, heads_k_stride=heads_k_stride, + softmax_scale=softmax_scale, ) if isinstance(out, tuple): out = out[0] @@ -337,6 +489,8 @@ def _ring_compute_attention(self, q, k, v, cu_seqlens, max_seqlen): AfmoeFlashAttention._compute_attention = _ring_compute_attention - from prime_rl.trainer.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeGatedFlashAttention + from prime_rl.trainer.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeGatedFlashAttention, + ) Qwen3_5MoeGatedFlashAttention._compute_attention = _ring_compute_attention diff --git a/src/prime_rl/trainer/models/layers/ring_attn.py b/src/prime_rl/trainer/models/layers/ring_attn.py index cdb1774200..d2fecb82dd 100644 --- a/src/prime_rl/trainer/models/layers/ring_attn.py +++ b/src/prime_rl/trainer/models/layers/ring_attn.py @@ -120,6 +120,7 @@ def forward( group_name: str, window_size_left: int = -1, window_size_right: int = -1, + softmax_scale: int | None = None ) -> torch.Tensor: group = dist.group.WORLD for pg in dist.distributed_c10d._world.pg_map: @@ -129,7 +130,10 @@ def forward( local_k_slice = slice(local_k_slice_start, local_k_slice_stop) window_size = (window_size_left, window_size_right) - softmax_scale = q.shape[-1] ** (-0.5) + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out_list = [] lse_list = [] @@ -292,6 +296,7 @@ def ring_fa3_varlen_func( heads_k_stride: int, group: dist.ProcessGroup, window_size: tuple[int, int] = (-1, -1), + softmax_scale: int | None = None ) -> torch.Tensor: return _RingFA3Varlen.apply( q, @@ -308,6 +313,7 @@ def ring_fa3_varlen_func( group.group_name, window_size[0], window_size[1], + softmax_scale ) @@ -415,6 +421,7 @@ def forward( group_name: str, window_size_left: int = -1, window_size_right: int = -1, + softmax_scale: int | None = None, ) -> torch.Tensor: group = dist.group.WORLD for pg in dist.distributed_c10d._world.pg_map: @@ -424,7 +431,8 @@ def forward( local_k_slice = slice(local_k_slice_start, local_k_slice_stop) window_size = (window_size_left, window_size_right) - softmax_scale = q.shape[-1] ** (-0.5) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) out_list = [] lse_list = [] @@ -587,6 +595,7 @@ def ring_fa4_varlen_func( heads_k_stride: int, group: dist.ProcessGroup, window_size: tuple[int, int] = (-1, -1), + softmax_scale: int | None = None ) -> torch.Tensor: return _RingFA4Varlen.apply( q, @@ -603,4 +612,5 @@ def ring_fa4_varlen_func( group.group_name, window_size[0], window_size[1], + softmax_scale, ) diff --git a/src/prime_rl/trainer/models/layers/ulysses_attn.py b/src/prime_rl/trainer/models/layers/ulysses_attn.py index 39139e4098..b89936ee1c 100644 --- a/src/prime_rl/trainer/models/layers/ulysses_attn.py +++ b/src/prime_rl/trainer/models/layers/ulysses_attn.py @@ -150,7 +150,7 @@ def substitute_ulysses_attn( flash_attn_version = 2 - def _ulysses_compute_attention(self, q, k, v, cu_seqlens, max_seqlen): + def _ulysses_compute_attention(self, q, k, v, cu_seqlens, max_seqlen, softmax_scale=None): # cu_seqlens / max_seqlen passed in are for the *local* sharded sequence; # ulysses needs the *full* ones (each rank holds the full seq after a2a). cu_seqlens_full = ULYSSES_PARAMS["cu_seqlens"] @@ -175,6 +175,7 @@ def _ulysses_compute_attention(self, q, k, v, cu_seqlens, max_seqlen): cp_size=cp_size, flash_attn_version=flash_attn_version, window_size=window_size, + softmax_scale = softmax_scale ) from prime_rl.trainer.models.layers.attn import FlashAttention From b1708fd110114a5a26d5c75fae9d8d510f35388d Mon Sep 17 00:00:00 2001 From: Ivan Petrov Date: Wed, 17 Jun 2026 02:13:48 +0300 Subject: [PATCH 2/7] added torch dynamo disable --- src/prime_rl/trainer/models/layers/attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/prime_rl/trainer/models/layers/attn.py b/src/prime_rl/trainer/models/layers/attn.py index c62e27ab09..4c3b0118e7 100644 --- a/src/prime_rl/trainer/models/layers/attn.py +++ b/src/prime_rl/trainer/models/layers/attn.py @@ -81,8 +81,10 @@ def __init__(self, config: AttentionConfig, flash_attn_version: int = 2): self.att_core_func = self._funcs[flash_attn_version] self.func = self._funcs_varlen[flash_attn_version] self._flash_attn_call = self.func + if self._flash_attn_version == 4: self._flash_attn_call = torch._dynamo.disable(self.func) + self.att_core_func = torch._dynamo.disable(self.att_core_func) def _fa_installed(self): """Checks that flash attention is installed.""" From 40d9bc58ccc9b1557d52b071511ff04dcc10ea0a Mon Sep 17 00:00:00 2001 From: Ivan Petrov Date: Wed, 17 Jun 2026 02:22:53 +0300 Subject: [PATCH 3/7] fixing returns in RingFA --- .../trainer/models/layers/ring_attn.py | 8 +- tests/unit/train/models/test_attention.py | 344 ++++++++++++++++++ 2 files changed, 348 insertions(+), 4 deletions(-) create mode 100644 tests/unit/train/models/test_attention.py diff --git a/src/prime_rl/trainer/models/layers/ring_attn.py b/src/prime_rl/trainer/models/layers/ring_attn.py index d2fecb82dd..b7bbd12f9c 100644 --- a/src/prime_rl/trainer/models/layers/ring_attn.py +++ b/src/prime_rl/trainer/models/layers/ring_attn.py @@ -279,8 +279,8 @@ def backward(ctx, dout: torch.Tensor): # Grads for: q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, # local_k_slice_start, local_k_slice_stop, heads_k_stride, causal, group_name, - # window_size_left, window_size_right - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None + # window_size_left, window_size_right, softmax_scale + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, def ring_fa3_varlen_func( @@ -578,8 +578,8 @@ def backward(ctx, dout: torch.Tensor): # Grads for: q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, # local_k_slice_start, local_k_slice_stop, heads_k_stride, causal, group_name, - # window_size_left, window_size_right - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None + # window_size_left, window_size_right, softmax_scale + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None def ring_fa4_varlen_func( diff --git a/tests/unit/train/models/test_attention.py b/tests/unit/train/models/test_attention.py new file mode 100644 index 0000000000..db18bb3c6b --- /dev/null +++ b/tests/unit/train/models/test_attention.py @@ -0,0 +1,344 @@ +"""Test attention implementations against SDPA reference.""" + +import tempfile +from pathlib import Path +import shutil +from typing import Generator + + +import torch +import torch.distributed as dist + +from prime_rl.trainer.models.layers.attn import ( + AttentionConfig, + FlashAttention, + SDPAAttention, + substitute_ring_attn, +) + +from prime_rl.trainer.models.layers.ulysses_attn import ( + substitute_ulysses_attn, + update_ulysses_params, +) + +from itertools import accumulate +from dataclasses import dataclass, field +import pytest + +import contextlib + + +@contextlib.contextmanager +def preserve_compute_attention(): + + from prime_rl.trainer.models.layers.attn import FlashAttention + from prime_rl.trainer.models.afmoe.modeling_afmoe import AfmoeFlashAttention + from prime_rl.trainer.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeGatedFlashAttention, + ) + + originals = { + cls: cls._compute_attention + for cls in (FlashAttention, AfmoeFlashAttention, Qwen3_5MoeGatedFlashAttention) + } + try: + yield + finally: + for cls, method in originals.items(): + cls._compute_attention = method + + +@pytest.fixture(scope="module") +def single_proc_group(): + if dist.is_initialized(): + yield dist.group.WORLD + return + + tmp_dir = Path(tempfile.mkdtemp(prefix="prime_rl_test_pg_")) + store_path = tmp_dir / "pg_store" + dist.init_process_group( + backend="gloo", + init_method=f"file://{store_path}", + world_size=1, + rank=0, + ) + try: + yield dist.group.WORLD + finally: + dist.destroy_process_group() + shutil.rmtree(tmp_dir, ignore_errors=True) + + +def patch_hf_adapter_params( + cu_seqlens: torch.Tensor, max_seqlen: int, local_k_slice: slice +) -> None: + from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS + + DATA_PARAMS["cu_seqlens_q"] = cu_seqlens + DATA_PARAMS["cu_seqlens_k"] = cu_seqlens + DATA_PARAMS["max_seqlen_q"] = max_seqlen + DATA_PARAMS["max_seqlen_k"] = max_seqlen + DATA_PARAMS["local_k_slice"] = local_k_slice + + +def copy_weights(attn_orig, attn_new): + # Copy weights to ensure same parameters + attn_new.load_state_dict(attn_orig.state_dict(), strict=True) + return attn_new + + +def iter_fa_versions() -> Generator[tuple[int, str], None, None]: + yield (2, "flash_attention_2") + yield (3, "flash_attention_3") + yield (4, "fa4") + + +@dataclass +class AttentionTestsInputs: + seq_lens: list[int] + cu_seqlens: torch.Tensor # cummulative seq lenths + max_seqlen: int + config: AttentionConfig + dtype: torch.dtype + hidden_states: torch.Tensor # shape [1, sum(seq_lens), hidden_size] + hidden_states_sdpa: torch.Tensor # [bs, max_seqlen, hidden_size] + sdpa_output: torch.Tensor # [bs, max_seqlen, hid_dim] + sdpa_output_packed: torch.Tensor # [1, sum(seq_lens), hidden_size] + sdpa_attn: ( + torch.nn.Module + ) # Prime-rl sdpa attention implementation with q,k,v layers + total_tokens: int = field(init=False) + + def __post_init__(self): + self.total_tokens = sum(self.seq_lens) + + +def generate_test_inputs( + seq_lens: list[int] = [128, 64], + softmax_scale: float | None = None, + seed: int = 42, + dtype=torch.bfloat16, + device="cuda", +) -> AttentionTestsInputs: + """ + Args: + seq_lens (list[int]): lengths of each sequence + ... + """ + torch.manual_seed(seed) + + cum_sl = [0] + list(accumulate(seq_lens)) + bs = len(seq_lens) + hidden_size = 4096 + head_dim = 128 + num_attention_heads = 32 + num_key_value_heads = 8 + + config = AttentionConfig( + hidden_size=hidden_size, + head_dim=head_dim, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + is_causal=True, + attention_bias=False, + use_qk_norm=False, + rms_norm_eps=1e-5, + scaling=softmax_scale, + ) + + # Create inputs + total_tokens = sum(seq_lens) + hidden_states = torch.randn( + 1, total_tokens, hidden_size, dtype=dtype, device=device + ) + cu_seqlens = torch.tensor(cum_sl, dtype=torch.int32, device=device) + max_seqlen = max(seq_lens) + + # Instantiate SDPA and FA3 attention + sdpa_attn = SDPAAttention(config).cuda().to(dtype) + + # --- Convert to padded (SDPA-compatible) format --- + hidden_states_sdpa = torch.zeros( + bs, max_seqlen, hidden_size, dtype=dtype, device=device + ) + + start = 0 + for i, sl in enumerate(seq_lens): + hidden_states_sdpa[i, :sl] = hidden_states[0, start : start + sl] + start += sl + + # Get SDPA output + sdpa_output, _ = sdpa_attn(hidden_states_sdpa) + + # SDPA output is [bs, seq_len, hidden]; varlen FA output is [1, total_tokens, hidden]. + sdpa_output_packed = torch.zeros( + 1, total_tokens, hidden_size, dtype=dtype, device=device + ) + + start = 0 + for i, sl in enumerate(seq_lens): + sdpa_output_packed[0, start : start + sl] = sdpa_output[i, :sl] + start += sl + + attn_inputs = AttentionTestsInputs( + seq_lens=seq_lens, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + config=config, + dtype=dtype, + hidden_states=hidden_states, + hidden_states_sdpa=hidden_states_sdpa, + sdpa_output=sdpa_output, + sdpa_output_packed=sdpa_output_packed, + sdpa_attn=sdpa_attn, + ) + + return attn_inputs + + +@pytest.mark.parametrize(("softmax_scale"), [(None), (1 / 15)]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_flash_attention_variants_vs_sdpa(softmax_scale: float | None): + """Test FA2, FA3, FA4 implementations match SDPA.""" + + attn_inputs = generate_test_inputs( + seq_lens=[64, 64, 64], softmax_scale=softmax_scale, seed=55 + ) + + dtype = attn_inputs.dtype + config = attn_inputs.config + sdpa_attn = attn_inputs.sdpa_attn + hidden_states_sdpa = attn_inputs.hidden_states_sdpa + sdpa_output = attn_inputs.sdpa_output + + for fa_ver, fa_name in iter_fa_versions(): + + print(f"Processing fa={fa_ver}") + fa_attn = FlashAttention(config, flash_attn_version=fa_ver).cuda().to(dtype) + if not fa_attn._fa_installed(): + # fa not installed + print(f"check skipped: {fa_name} not installed") + continue + + fa_attn = copy_weights(sdpa_attn, fa_attn) + fa_output, _ = fa_attn(hidden_states_sdpa, cu_seqlens=None, max_seqlen=None) + + diff = torch.abs(fa_output.float() - sdpa_output.float()).max() + torch.testing.assert_close( + fa_output.float(), + sdpa_output.float(), + atol=1e-2, + rtol=0, + msg=f"FA{fa_ver} output differs by {diff}", + ) + + +@pytest.mark.parametrize(("softmax_scale"), [(None), (1 / 15)]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_ring_varlen_vs_sdpa(softmax_scale: float | None, single_proc_group): + """Test ring varlen flash attention against SDPA.""" + + attn_inputs = generate_test_inputs( + seq_lens=[128, 64], softmax_scale=softmax_scale, seed=83 + ) + + dtype = attn_inputs.dtype + config = attn_inputs.config + sdpa_attn = attn_inputs.sdpa_attn + hidden_states = attn_inputs.hidden_states + sdpa_output_packed = attn_inputs.sdpa_output_packed + cu_seqlens = attn_inputs.cu_seqlens + max_seqlen = attn_inputs.max_seqlen + total_tokens = attn_inputs.total_tokens + + # local_k_slice indexes the *token* dim of the gathered kv_buffer, not heads. + # With world_size=1, local_k_slice = slice(0, total_tokens) + patch_hf_adapter_params( + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + local_k_slice=slice(0, total_tokens), + ) + + with preserve_compute_attention(): + for fa_ver, fa_name in iter_fa_versions(): + + print(f"Processing varlen fa={fa_ver}") + substitute_ring_attn( + single_proc_group, + heads_k_stride=config.num_key_value_heads, + attn_impl=fa_name, + ) + + varlen_fa_attn = ( + FlashAttention(config, flash_attn_version=fa_ver).cuda().to(dtype) + ) + + if not varlen_fa_attn._fa_installed(): + # fa not installed + print(f"check skipped: {fa_name} not installed") + continue + + varlen_fa_attn = copy_weights(sdpa_attn, varlen_fa_attn) + fa_output, _ = varlen_fa_attn( + hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + diff = torch.abs(fa_output.float() - sdpa_output_packed.float()).max() + torch.testing.assert_close( + fa_output.float(), + sdpa_output_packed.float(), + atol=1e-2, + rtol=0, + msg=f"Ring Varlen FA{fa_ver} output differs by {diff}", + ) + + +@pytest.mark.parametrize(("softmax_scale"), [(None), (1 / 15)]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_ulysses_varlen_vs_sdpa(softmax_scale: float | None, single_proc_group): + """Test Ulysses varlen attention against SDPA.""" + + attn_inputs = generate_test_inputs( + seq_lens=[128, 64], softmax_scale=softmax_scale, seed=91 + ) + + dtype = attn_inputs.dtype + config = attn_inputs.config + sdpa_attn = attn_inputs.sdpa_attn + hidden_states = attn_inputs.hidden_states + sdpa_output_packed = attn_inputs.sdpa_output_packed + cu_seqlens = attn_inputs.cu_seqlens + max_seqlen = attn_inputs.max_seqlen + total_tokens = attn_inputs.total_tokens + + # local_k_slice indexes the *token* dim of the gathered kv_buffer, not heads. + # With world_size=1, this rank owns all tokens. + update_ulysses_params(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + + with preserve_compute_attention(): + for fa_ver, fa_name in iter_fa_versions(): + print(f"Processing varlen fa={fa_ver}") + + substitute_ulysses_attn(single_proc_group, attn_impl=fa_name) + + varlen_fa_attn = ( + FlashAttention(config, flash_attn_version=fa_ver).cuda().to(dtype) + ) + + if not varlen_fa_attn._fa_installed(): + # fa not installed + print(f"check skipped: {fa_name} not installed") + continue + + varlen_fa_attn = copy_weights(sdpa_attn, varlen_fa_attn) + fa_output, _ = varlen_fa_attn( + hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + + diff = torch.abs(fa_output.float() - sdpa_output_packed.float()).max() + torch.testing.assert_close( + fa_output.float(), + sdpa_output_packed.float(), + atol=1e-2, + rtol=0, + msg=f"Ring Ulysses FA{fa_ver} output differs by {diff}", + ) From 7157f8636c01e27c1ee83282a94c93239c73bae3 Mon Sep 17 00:00:00 2001 From: Ivan Petrov Date: Wed, 17 Jun 2026 12:09:25 +0300 Subject: [PATCH 4/7] fixing backward return in RingFA --- src/prime_rl/trainer/models/layers/ring_attn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/prime_rl/trainer/models/layers/ring_attn.py b/src/prime_rl/trainer/models/layers/ring_attn.py index b7bbd12f9c..3ffb5aa06a 100644 --- a/src/prime_rl/trainer/models/layers/ring_attn.py +++ b/src/prime_rl/trainer/models/layers/ring_attn.py @@ -280,8 +280,7 @@ def backward(ctx, dout: torch.Tensor): # Grads for: q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, # local_k_slice_start, local_k_slice_stop, heads_k_stride, causal, group_name, # window_size_left, window_size_right, softmax_scale - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, - + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None def ring_fa3_varlen_func( q: torch.Tensor, From 39f3a82bcf209390804346e4eec43bdde80abc10 Mon Sep 17 00:00:00 2001 From: Ivan Petrov Date: Wed, 17 Jun 2026 14:44:43 +0300 Subject: [PATCH 5/7] added sliding window support into core attetion --- src/prime_rl/trainer/models/layers/attn.py | 58 ++++++++++++------ .../trainer/models/layers/ulysses_attn.py | 6 +- tests/unit/train/models/test_attention.py | 60 +++++++++++++++++++ 3 files changed, 103 insertions(+), 21 deletions(-) diff --git a/src/prime_rl/trainer/models/layers/attn.py b/src/prime_rl/trainer/models/layers/attn.py index 4c3b0118e7..7ec2a0f581 100644 --- a/src/prime_rl/trainer/models/layers/attn.py +++ b/src/prime_rl/trainer/models/layers/attn.py @@ -1,5 +1,5 @@ import functools -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Literal import torch @@ -46,6 +46,7 @@ class AttentionConfig: qk_norm_type: Literal["per_head", "per_layer"] = "per_head" output_bias: bool = False scaling: float | None = None + window_size_left: int | None = None def __post_init__(self): if self.scaling is None: @@ -76,17 +77,20 @@ def __init__(self, config: AttentionConfig, flash_attn_version: int = 2): super().__init__() self.scaling = config.scaling self.is_causal = config.is_causal + self.window_size_left = config.window_size_left self._flash_attn_version = flash_attn_version - self.att_core_func = self._funcs[flash_attn_version] - self.func = self._funcs_varlen[flash_attn_version] - self._flash_attn_call = self.func - + self.func1 = self._funcs[flash_attn_version] + self.func2 = self._funcs_varlen[flash_attn_version] + + self._flash_attn_call = self.func1 + self._varlen_flash_attn_call = self.func2 + if self._flash_attn_version == 4: - self._flash_attn_call = torch._dynamo.disable(self.func) - self.att_core_func = torch._dynamo.disable(self.att_core_func) + self._flash_attn_call = torch._dynamo.disable(self.func1) + self._varlen_flash_attn_call = torch._dynamo.disable(self.func2) - def _fa_installed(self): + def _fa_installed(self) -> bool: """Checks that flash attention is installed.""" return self._funcs_varlen[self._flash_attn_version] is not None @@ -105,17 +109,18 @@ def _compute_attention( """ kwargs: dict = {"causal": True, "softmax_scale": softmax_scale} - sliding_window = getattr(self, "sliding_window", None) - if sliding_window is not None: - kwargs["window_size"] = (sliding_window - 1, 0) + window_size_left = getattr(self, "window_size_left", None) + + if window_size_left is not None: + kwargs["window_size"] = (window_size_left - 1, 0) if self._flash_attn_version == 4: # FA4's flash_attn_varlen_func has qv as the 4th positional arg, # so cu_seqlens must be passed as keyword args to avoid misalignment. kwargs["cu_seqlens_q"] = cu_seqlens kwargs["cu_seqlens_k"] = cu_seqlens - out = self._flash_attn_call(q, k, v, **kwargs) + out = self._varlen_flash_attn_call(q, k, v, **kwargs) else: - out = self._flash_attn_call( + out = self._varlen_flash_attn_call( q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, **kwargs ) if isinstance(out, tuple): @@ -137,14 +142,23 @@ def _attention_core( softmax_scale = self.scaling bs = query_states.shape[0] + if cu_seqlens is None: # Non-varlen: q/k/v are [bs, sl, nh, hdim]; FA returns same shape. - out = self.att_core_func( + + window_size = (-1, -1) + window_size_left = getattr(self, "window_size_left", None) + + if window_size_left is not None: + window_size = (window_size_left - 1, 0) + + out = self._flash_attn_call( query_states, key_states, value_states, causal=True, softmax_scale=softmax_scale, + window_size=window_size, ) if isinstance(out, tuple): # 'fa4' returns tuple @@ -188,8 +202,12 @@ def _attention_core( key_states: torch.Tensor, value_states: torch.Tensor, softmax_scale: int | None = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: - + """SDPA Attention wrapper + Inputs: q,k,v - [bs, nh, sl, hdim] + Returns: out - [bs, sl, nh, hdim] + """ if softmax_scale is None: softmax_scale = self.scaling @@ -201,8 +219,12 @@ def _attention_core( value_states, scale=softmax_scale, is_causal=self.is_causal, + attn_mask=attn_mask, ) + out = out.transpose(1, 2).contiguous() + + # returns out - [bs, sl, nh, hdim] return out.view(out.shape[0], out.shape[1], -1) @@ -462,9 +484,9 @@ def _ring_compute_attention( from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS window_size = (-1, -1) - sliding_window = getattr(self, "sliding_window", None) - if sliding_window is not None: - window_size = (sliding_window - 1, 0) + window_size_left = getattr(self, "window_size_left", None) + if window_size_left is not None: + window_size = (window_size_left - 1, 0) out = ring_func( q, diff --git a/src/prime_rl/trainer/models/layers/ulysses_attn.py b/src/prime_rl/trainer/models/layers/ulysses_attn.py index b89936ee1c..578f5d4813 100644 --- a/src/prime_rl/trainer/models/layers/ulysses_attn.py +++ b/src/prime_rl/trainer/models/layers/ulysses_attn.py @@ -157,9 +157,9 @@ def _ulysses_compute_attention(self, q, k, v, cu_seqlens, max_seqlen, softmax_sc max_seqlen_full = ULYSSES_PARAMS["max_seqlen"] window_size = (-1, -1) - sliding_window = getattr(self, "sliding_window", None) - if sliding_window is not None: - window_size = (sliding_window - 1, 0) + window_size_left = getattr(self, "window_size_left", None) + if window_size_left is not None: + window_size = (window_size_left - 1, 0) return ulysses_flash_attn_varlen_func( flash_fn, diff --git a/tests/unit/train/models/test_attention.py b/tests/unit/train/models/test_attention.py index db18bb3c6b..e312aec046 100644 --- a/tests/unit/train/models/test_attention.py +++ b/tests/unit/train/models/test_attention.py @@ -12,7 +12,9 @@ from prime_rl.trainer.models.layers.attn import ( AttentionConfig, FlashAttention, + FlashAttentionCore, SDPAAttention, + SDPAAttentionCore, substitute_ring_attn, ) @@ -196,6 +198,64 @@ def generate_test_inputs( return attn_inputs +@pytest.mark.parametrize("window_size_left", [None, 32, 64]) # None = full attention +def test_sliding_window_attention(window_size_left: int | None): + + batch_size, seq_len, num_heads, head_dim = 2, 2048, 4, 64 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn( + batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + k = torch.randn_like(q) + v = torch.randn_like(q) + + config = AttentionConfig( + hidden_size=1024, + head_dim=head_dim, + num_attention_heads=num_heads, + num_key_value_heads=num_heads, + attention_bias=False, + use_qk_norm=False, + rms_norm_eps=1e-5, + scaling=None, + is_causal=True if window_size_left is None else False, + window_size_left=window_size_left, + ) + + flash_attn = FlashAttentionCore(config, flash_attn_version=2).to(device, dtype) + flash_out = flash_attn._attention_core(q, k, v) + + # SDPA does NOT natively support sliding window, but you can construct the attention mask: + if window_size_left is not None: + + mask = torch.full( + (seq_len, seq_len), float("-inf"), device=q.device, dtype=q.dtype + ) + mask = torch.triu(mask, diagonal=1) # causal + + # sliding window: mask positions older than (window_size_left - 1) keys back + window = torch.tril( + torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), + diagonal=-window_size_left, + ) + mask.masked_fill_(window, float("-inf")) + else: + mask = None + + sdpa_attn = SDPAAttentionCore(config).cuda().to(dtype) + sdpa_out = sdpa_attn._attention_core( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_mask=mask, + ) + + # Compare outputs + torch.testing.assert_close(flash_out, sdpa_out, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize(("softmax_scale"), [(None), (1 / 15)]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") def test_flash_attention_variants_vs_sdpa(softmax_scale: float | None): From 34292900054b2d4fd5eae7d2db7e5c4d9992dc18 Mon Sep 17 00:00:00 2001 From: Ivan Petrov Date: Wed, 17 Jun 2026 14:51:26 +0300 Subject: [PATCH 6/7] reversed 'window_size_left' to 'sliding_window' --- src/prime_rl/trainer/models/layers/attn.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/prime_rl/trainer/models/layers/attn.py b/src/prime_rl/trainer/models/layers/attn.py index 7ec2a0f581..6a7dff42ee 100644 --- a/src/prime_rl/trainer/models/layers/attn.py +++ b/src/prime_rl/trainer/models/layers/attn.py @@ -77,7 +77,7 @@ def __init__(self, config: AttentionConfig, flash_attn_version: int = 2): super().__init__() self.scaling = config.scaling self.is_causal = config.is_causal - self.window_size_left = config.window_size_left + self.sliding_window = config.window_size_left self._flash_attn_version = flash_attn_version self.func1 = self._funcs[flash_attn_version] @@ -109,10 +109,11 @@ def _compute_attention( """ kwargs: dict = {"causal": True, "softmax_scale": softmax_scale} - window_size_left = getattr(self, "window_size_left", None) - if window_size_left is not None: - kwargs["window_size"] = (window_size_left - 1, 0) + sliding_window = getattr(self, "sliding_window", None) + if sliding_window is not None: + kwargs["window_size"] = (sliding_window - 1, 0) + if self._flash_attn_version == 4: # FA4's flash_attn_varlen_func has qv as the 4th positional arg, # so cu_seqlens must be passed as keyword args to avoid misalignment. @@ -147,10 +148,9 @@ def _attention_core( # Non-varlen: q/k/v are [bs, sl, nh, hdim]; FA returns same shape. window_size = (-1, -1) - window_size_left = getattr(self, "window_size_left", None) - - if window_size_left is not None: - window_size = (window_size_left - 1, 0) + sliding_window = getattr(self, "sliding_window", None) + if sliding_window is not None: + window_size = (sliding_window - 1, 0) out = self._flash_attn_call( query_states, @@ -484,9 +484,9 @@ def _ring_compute_attention( from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS window_size = (-1, -1) - window_size_left = getattr(self, "window_size_left", None) - if window_size_left is not None: - window_size = (window_size_left - 1, 0) + sliding_window = getattr(self, "sliding_window", None) + if sliding_window is not None: + window_size = (sliding_window - 1, 0) out = ring_func( q, From 0b83f366ea25c96080665ad70e715cf9735ebb21 Mon Sep 17 00:00:00 2001 From: Ivan Petrov Date: Wed, 17 Jun 2026 14:52:48 +0300 Subject: [PATCH 7/7] reversed 'window_size_left' to 'sliding_window' in ulyssses --- src/prime_rl/trainer/models/layers/ulysses_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/trainer/models/layers/ulysses_attn.py b/src/prime_rl/trainer/models/layers/ulysses_attn.py index 578f5d4813..b89936ee1c 100644 --- a/src/prime_rl/trainer/models/layers/ulysses_attn.py +++ b/src/prime_rl/trainer/models/layers/ulysses_attn.py @@ -157,9 +157,9 @@ def _ulysses_compute_attention(self, q, k, v, cu_seqlens, max_seqlen, softmax_sc max_seqlen_full = ULYSSES_PARAMS["max_seqlen"] window_size = (-1, -1) - window_size_left = getattr(self, "window_size_left", None) - if window_size_left is not None: - window_size = (window_size_left - 1, 0) + sliding_window = getattr(self, "sliding_window", None) + if sliding_window is not None: + window_size = (sliding_window - 1, 0) return ulysses_flash_attn_varlen_func( flash_fn,