Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/liger_kernel/chunked_loss/dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def forward(
ref_bias=None,
ignore_index=-100,
beta=0.1,
alpha=1.0,
compute_nll_loss=False,
compiled=True,
use_ref_model=True,
Expand All @@ -126,7 +127,8 @@ def forward(
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
ignore_index (int): Index to ignore in loss computation
beta (float): Weight for the odds ratio loss
beta (float): Weight for the direct preference loss
alpha (float): Weight for the NLL loss component
compute_nll_loss (bool): Whether to compute the NLL loss
compiled (bool): Whether to use torch compile
use_ref_model (bool): Whether to use a reference model
Expand All @@ -144,6 +146,7 @@ def forward(
bias=bias,
ignore_index=ignore_index,
beta=beta,
alpha=alpha,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
use_ref_model=use_ref_model,
Expand All @@ -158,7 +161,7 @@ def forward(
@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None, None, None, None, None, None, None, None
return *grads, None, None, None, None, None, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
Expand All @@ -170,6 +173,7 @@ def __init__(
self,
ignore_index: int = -100,
beta: float = 0.1,
alpha: float = 1.0,
compute_nll_loss: bool = False,
compiled: bool = True,
use_ref_model: bool = True,
Expand All @@ -180,7 +184,8 @@ def __init__(
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
beta (float): Weight for the direct preference loss.
alpha (float): Weight for the NLL loss component.
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
use_ref_model (bool): Whether to use a reference model for the DPO loss.
Expand All @@ -190,6 +195,7 @@ def __init__(
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.alpha = alpha
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.use_ref_model = use_ref_model
Expand Down Expand Up @@ -220,6 +226,7 @@ def forward(
ref_bias,
self.ignore_index,
self.beta,
self.alpha,
self.compute_nll_loss,
self.compiled,
self.use_ref_model,
Expand Down
59 changes: 49 additions & 10 deletions test/chunked_loss/test_dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,9 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
ref_input,
ref_weight1,
ref_bias1,
-100,
0.1,
-100, # ignore_index
0.1, # beta
1.0, # alpha
compute_nll_loss,
)
loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
Expand All @@ -655,8 +656,9 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
ref_input,
ref_weight2,
ref_bias2,
-100,
0.1,
-100, # ignore_index
0.1, # beta
1.0, # alpha
compute_nll_loss,
)

Expand Down Expand Up @@ -886,14 +888,15 @@ def test_correctness_functional_apo_loss_types(
ref_input,
ref_weight1,
ref_bias1,
-100,
0.1,
-100, # ignore_index
0.1, # beta
1.0, # alpha
compute_nll_loss,
True, # compiled
True, # use_ref_model
True, # compiled
True, # use_ref_model
False, # average_log_prob
1, # chunk_size
loss_type, # loss_type
1, # chunk_size
loss_type,
)

# For comparison, create a LigerFusedLinearDPOLoss with the loss_type
Expand Down Expand Up @@ -936,3 +939,39 @@ def test_invalid_loss_type():
# Should not raise an exception
loss_fn = LigerFusedLinearDPOLoss(loss_type=loss_type)
assert loss_fn.loss_type == loss_type


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_alpha_scales_nll_loss(dtype):
"""
Verify that alpha is actually forwarded and scales the NLL component.
With compute_nll_loss=True, loss(alpha=2) should differ from loss(alpha=1).
"""
B, T, H, V = 4, 16, 32, 64
atol = 1e-4 if dtype == torch.float32 else 5e-2

_weight = torch.randn(V, H, device=device, dtype=dtype)
_ref_weight = torch.randn(V, H, device=device, dtype=dtype)
_input = torch.randn(B, T, H, device=device, dtype=dtype)
target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)

def run(alpha):
inp = _input.detach().clone().requires_grad_(True)
w = _weight.detach().clone().requires_grad_(True)
rw = _ref_weight.detach().clone().requires_grad_(True)
loss_fn = LigerFusedLinearDPOLoss(
beta=0.1,
alpha=alpha,
compute_nll_loss=True,
use_ref_model=True,
average_log_prob=False,
)
loss, _ = loss_fn(w, inp, target, None, _input.detach(), rw, None)
return loss

loss_alpha1 = run(alpha=1.0)
loss_alpha2 = run(alpha=2.0)

assert not torch.allclose(loss_alpha1, loss_alpha2, atol=atol), (
f"Expected losses to differ when alpha changes, but got {loss_alpha1} vs {loss_alpha2}"
)