diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py
index f7a14e539..633aebf43 100644
--- a/src/liger_kernel/chunked_loss/dpo_loss.py
+++ b/src/liger_kernel/chunked_loss/dpo_loss.py
@@ -108,6 +108,7 @@ def forward(
         ref_bias=None,
         ignore_index=-100,
         beta=0.1,
+        alpha=1.0,
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
@@ -126,7 +127,8 @@ def forward(
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             ignore_index (int): Index to ignore in loss computation
-            beta (float): Weight for the odds ratio loss
+            beta (float): Weight for the direct preference loss
+            alpha (float): Weight for the NLL loss component
             compute_nll_loss (bool): Whether to compute the NLL loss
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -144,6 +146,7 @@ def forward(
             bias=bias,
             ignore_index=ignore_index,
             beta=beta,
+            alpha=alpha,
             compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             use_ref_model=use_ref_model,
@@ -158,7 +161,7 @@
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None, None


 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -170,6 +173,7 @@ def __init__(
         self,
         ignore_index: int = -100,
         beta: float = 0.1,
+        alpha: float = 1.0,
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
@@ -180,7 +184,8 @@ def __init__(
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
-            beta (float): Weight for the odds ratio loss.
+            beta (float): Weight for the direct preference loss.
+            alpha (float): Weight for the NLL loss component.
             compute_nll_loss (bool): Whether to compute the NLL loss.
             compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
@@ -190,6 +195,7 @@ def __init__(
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
+        self.alpha = alpha
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
@@ -220,6 +226,7 @@ def forward(
             ref_bias,
             self.ignore_index,
             self.beta,
+            self.alpha,
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py
index de5762f26..2c39af796 100644
--- a/test/chunked_loss/test_dpo_loss.py
+++ b/test/chunked_loss/test_dpo_loss.py
@@ -643,8 +643,9 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
         ref_input,
         ref_weight1,
         ref_bias1,
-        -100,
-        0.1,
+        -100,  # ignore_index
+        0.1,  # beta
+        1.0,  # alpha
         compute_nll_loss,
     )
     loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
@@ -655,8 +656,9 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
         ref_input,
         ref_weight2,
         ref_bias2,
-        -100,
-        0.1,
+        -100,  # ignore_index
+        0.1,  # beta
+        1.0,  # alpha
         compute_nll_loss,
     )

@@ -886,14 +888,15 @@ def test_correctness_functional_apo_loss_types(
         ref_input,
         ref_weight1,
         ref_bias1,
-        -100,
-        0.1,
+        -100,  # ignore_index
+        0.1,  # beta
+        1.0,  # alpha
         compute_nll_loss,
-        True, # compiled
-        True, # use_ref_model
+        True,  # compiled
+        True,  # use_ref_model
         False,  # average_log_prob
-        1, # chunk_size
-        loss_type,  # loss_type
+        1,  # chunk_size
+        loss_type,
     )

     # For comparison, create a LigerFusedLinearDPOLoss with the loss_type
@@ -936,3 +939,39 @@ def test_invalid_loss_type():
         # Should not raise an exception
         loss_fn = LigerFusedLinearDPOLoss(loss_type=loss_type)
         assert loss_fn.loss_type == loss_type
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_alpha_scales_nll_loss(dtype):
+    """
+    Verify that alpha is actually forwarded and scales the NLL component.
+    With compute_nll_loss=True, loss(alpha=2) should differ from loss(alpha=1).
+    """
+    B, T, H, V = 4, 16, 32, 64
+    atol = 1e-4 if dtype == torch.float32 else 5e-2
+
+    _weight = torch.randn(V, H, device=device, dtype=dtype)
+    _ref_weight = torch.randn(V, H, device=device, dtype=dtype)
+    _input = torch.randn(B, T, H, device=device, dtype=dtype)
+    target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)
+
+    def run(alpha):
+        inp = _input.detach().clone().requires_grad_(True)
+        w = _weight.detach().clone().requires_grad_(True)
+        rw = _ref_weight.detach().clone().requires_grad_(True)
+        loss_fn = LigerFusedLinearDPOLoss(
+            beta=0.1,
+            alpha=alpha,
+            compute_nll_loss=True,
+            use_ref_model=True,
+            average_log_prob=False,
+        )
+        loss, _ = loss_fn(w, inp, target, None, _input.detach(), rw, None)
+        return loss
+
+    loss_alpha1 = run(alpha=1.0)
+    loss_alpha2 = run(alpha=2.0)
+
+    assert not torch.allclose(loss_alpha1, loss_alpha2, atol=atol), (
+        f"Expected losses to differ when alpha changes, but got {loss_alpha1} vs {loss_alpha2}"
+    )
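
Reviewer note (not part of the patch): a minimal usage sketch of the new alpha knob. It assumes LigerFusedLinearDPOLoss is importable from liger_kernel.chunked_loss and that, as in the existing tests, the batch dimension is even with the first half holding chosen and the second half rejected sequences; shapes and values are illustrative only.

    import torch

    from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, T, H, V = 2, 8, 32, 64  # B must be even: first half chosen, second half rejected

    weight = torch.randn(V, H, device=device, requires_grad=True)
    ref_weight = torch.randn(V, H, device=device)
    hidden = torch.randn(B, T, H, device=device, requires_grad=True)
    target = torch.randint(0, V, (B, T), device=device)

    # alpha scales the NLL term added on top of the beta-weighted preference loss
    # when compute_nll_loss=True; the default alpha=1.0 matches the old behavior.
    loss_fn = LigerFusedLinearDPOLoss(beta=0.1, alpha=0.5, compute_nll_loss=True)
    loss, aux = loss_fn(weight, hidden, target, None, hidden.detach(), ref_weight)
    loss.backward()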