From 314f4f81feef41cc9af69ae8e89a6b9dacf5b570 Mon Sep 17 00:00:00 2001
From: micdoh
Date: Thu, 16 Apr 2026 12:01:27 +0100
Subject: [PATCH] fix: expose alpha param through LigerFusedLinearDPO public
 API

LigerFusedLinearDPOFunction.forward() accepted every base-class parameter
except alpha, so the NLL scaling weight silently defaulted to 1.0 regardless
of what callers passed. This adds alpha to both the Function and the
LigerFusedLinearDPOLoss module, fixes the positional-arg order in the
existing functional tests, and adds a regression test that verifies alpha
actually affects the loss value when compute_nll_loss=True.
---
 src/liger_kernel/chunked_loss/dpo_loss.py | 13 +++--
 test/chunked_loss/test_dpo_loss.py        | 59 +++++++++++++++++++----
 2 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py
index f7a14e539..633aebf43 100644
--- a/src/liger_kernel/chunked_loss/dpo_loss.py
+++ b/src/liger_kernel/chunked_loss/dpo_loss.py
@@ -108,6 +108,7 @@ def forward(
         ref_bias=None,
         ignore_index=-100,
         beta=0.1,
+        alpha=1.0,
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
@@ -126,7 +127,8 @@ def forward(
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             ignore_index (int): Index to ignore in loss computation
-            beta (float): Weight for the odds ratio loss
+            beta (float): Weight for the direct preference loss
+            alpha (float): Weight for the NLL loss component
             compute_nll_loss (bool): Whether to compute the NLL loss
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -144,6 +146,7 @@ def forward(
             bias=bias,
             ignore_index=ignore_index,
             beta=beta,
+            alpha=alpha,
             compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             use_ref_model=use_ref_model,
@@ -158,7 +161,7 @@ def forward(
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -170,6 +173,7 @@ def __init__(
         self,
         ignore_index: int = -100,
         beta: float = 0.1,
+        alpha: float = 1.0,
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
@@ -180,7 +184,8 @@ def __init__(
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
-            beta (float): Weight for the odds ratio loss.
+            beta (float): Weight for the direct preference loss.
+            alpha (float): Weight for the NLL loss component.
             compute_nll_loss (bool): Whether to compute the NLL loss.
             compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
@@ -190,6 +195,7 @@ def __init__(
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
+        self.alpha = alpha
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
@@ -220,6 +226,7 @@ def forward(
             ref_bias,
             self.ignore_index,
             self.beta,
+            self.alpha,
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py
index de5762f26..2c39af796 100644
--- a/test/chunked_loss/test_dpo_loss.py
+++ b/test/chunked_loss/test_dpo_loss.py
@@ -643,8 +643,9 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
         ref_input,
         ref_weight1,
         ref_bias1,
-        -100,
-        0.1,
+        -100,  # ignore_index
+        0.1,  # beta
+        1.0,  # alpha
         compute_nll_loss,
     )
     loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
@@ -655,8 +656,9 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
         ref_input,
         ref_weight2,
         ref_bias2,
-        -100,
-        0.1,
+        -100,  # ignore_index
+        0.1,  # beta
+        1.0,  # alpha
         compute_nll_loss,
     )
 
@@ -886,14 +888,15 @@ def test_correctness_functional_apo_loss_types(
         ref_input,
         ref_weight1,
         ref_bias1,
-        -100,
-        0.1,
+        -100,  # ignore_index
+        0.1,  # beta
+        1.0,  # alpha
         compute_nll_loss,
-        True, # compiled
-        True, # use_ref_model
+        True,  # compiled
+        True,  # use_ref_model
         False,  # average_log_prob
-        1, # chunk_size
-        loss_type, # loss_type
+        1,  # chunk_size
+        loss_type,
     )
 
     # For comparison, create a LigerFusedLinearDPOLoss with the loss_type
@@ -936,3 +939,39 @@ def test_invalid_loss_type():
     # Should not raise an exception
     loss_fn = LigerFusedLinearDPOLoss(loss_type=loss_type)
     assert loss_fn.loss_type == loss_type
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_alpha_scales_nll_loss(dtype):
+    """
+    Verify that alpha is actually forwarded and scales the NLL component.
+    With compute_nll_loss=True, loss(alpha=2) should differ from loss(alpha=1).
+    """
+    B, T, H, V = 4, 16, 32, 64
+    atol = 1e-4 if dtype == torch.float32 else 5e-2
+
+    _weight = torch.randn(V, H, device=device, dtype=dtype)
+    _ref_weight = torch.randn(V, H, device=device, dtype=dtype)
+    _input = torch.randn(B, T, H, device=device, dtype=dtype)
+    target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)
+
+    def run(alpha):
+        inp = _input.detach().clone().requires_grad_(True)
+        w = _weight.detach().clone().requires_grad_(True)
+        rw = _ref_weight.detach().clone().requires_grad_(True)
+        loss_fn = LigerFusedLinearDPOLoss(
+            beta=0.1,
+            alpha=alpha,
+            compute_nll_loss=True,
+            use_ref_model=True,
+            average_log_prob=False,
+        )
+        loss, _ = loss_fn(w, inp, target, None, _input.detach(), rw, None)
+        return loss
+
+    loss_alpha1 = run(alpha=1.0)
+    loss_alpha2 = run(alpha=2.0)
+
+    assert not torch.allclose(loss_alpha1, loss_alpha2, atol=atol), (
+        f"Expected losses to differ when alpha changes, but got {loss_alpha1} vs {loss_alpha2}"
+    )
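
Usage sketch (not part of the diff above): with alpha exposed end to end, the
NLL weight can now be set from the public module API. This is a minimal,
illustrative example only; the shapes, hyperparameter values, and CPU fallback
are assumptions, Liger's fused kernels normally expect a CUDA-capable device,
and the call signature (weight first, then input) mirrors the regression test
in the patch.

    import torch
    from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

    B, T, H, V = 2, 8, 32, 64  # illustrative sizes, not from the patch
    device = "cuda" if torch.cuda.is_available() else "cpu"

    weight = torch.randn(V, H, device=device, requires_grad=True)
    ref_weight = torch.randn(V, H, device=device)
    hidden = torch.randn(B, T, H, device=device, requires_grad=True)
    target = torch.randint(0, V, (B, T), device=device)

    # alpha now reaches the kernel instead of silently staying at 1.0;
    # it only has an effect when compute_nll_loss=True.
    loss_fn = LigerFusedLinearDPOLoss(beta=0.1, alpha=0.5, compute_nll_loss=True)
    loss, aux = loss_fn(weight, hidden, target, None, hidden.detach(), ref_weight, None)
    loss.backward()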