Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions auto_round/data_type/nvfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,18 @@ def cast_to_ue5m3(tensor):
return res


def cast_to_ue5m3_ste(x):
fp4 = (cast_to_ue5m3(x).to(x.dtype) - x).detach() + x
class _UE5M3CastSTE(torch.autograd.Function):
    """Autograd function wrapping ``cast_to_ue5m3`` with a pass-through gradient.

    Forward returns ``cast_to_ue5m3(x)`` converted back to the dtype of ``x``.
    Backward hands the incoming gradient back unchanged.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return ``cast_to_ue5m3(x)`` converted to ``x.dtype``."""
        source_dtype = x.dtype
        casted = cast_to_ue5m3(x)
        return casted.to(source_dtype)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``grad_output`` unchanged as the gradient for ``x``."""
        return grad_output

return fp4

def cast_to_ue5m3_ste(x):
    """Run ``x`` through ``_UE5M3CastSTE`` and return the result.

    Args:
        x: Tensor forwarded to ``_UE5M3CastSTE.apply``.

    Returns:
        The output of ``_UE5M3CastSTE.apply(x)``.
    """
    casted = _UE5M3CastSTE.apply(x)
    return casted


Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cast_to_ue5m3_ste now relies on a custom torch.autograd.Function, but there are no tests asserting the straight-through gradient behavior for this path. Consider adding a unit test that checks forward equivalence to cast_to_ue5m3 and that gradients through cast_to_ue5m3_ste are identity (within expected dtype tolerances).

Suggested change
def _validate_cast_to_ue5m3_ste():
    """Validate the STE contract for cast_to_ue5m3_ste.

    This helper is intentionally kept local to the implementation so tests can
    exercise the same forward and backward path used in production:

    * forward output matches ``cast_to_ue5m3`` (converted to the input dtype)
    * backward pass behaves as an identity straight-through estimator

    Raises:
        AssertionError: If either comparison made with
            ``torch.testing.assert_close`` fails.
    """
    test_specs = {
        torch.float16: {"rtol": 1e-3, "atol": 1e-3},
        torch.float32: {"rtol": 1e-6, "atol": 1e-6},
        torch.bfloat16: {"rtol": 1e-2, "atol": 1e-2},
    }
    base = torch.tensor(
        [-480.0, -31.5, -2.75, -0.5, 0.0, 0.375, 1.0, 7.5, 96.0, 57344.0],
        dtype=torch.float32,
    )
    for dtype, tol in test_specs.items():
        x = base.to(dtype).clone().detach().requires_grad_(True)
        # The production forward converts the cast result back to the input
        # dtype. Apply the same conversion to the reference, because
        # assert_close also compares dtypes by default.
        y_ref = cast_to_ue5m3(x.detach()).to(dtype)
        y_ste = cast_to_ue5m3_ste(x)
        torch.testing.assert_close(y_ste.detach(), y_ref, **tol)

        upstream_grad = torch.linspace(-1.0, 1.0, steps=x.numel(), dtype=torch.float32).to(dtype)
        y_ste.backward(upstream_grad)
        # Identity STE: the gradient reaching x equals the upstream gradient.
        torch.testing.assert_close(x.grad, upstream_grad, **tol)

Copilot uses AI. Check for mistakes.
def ref_fp4_quant(x, global_scale, block_size=16, v=0, max_scale=1.0):
Expand Down
71 changes: 57 additions & 14 deletions auto_round/data_type/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def round_ste(x: torch.Tensor):
Returns:
torch.Tensor
"""
return (x.round() - x).detach() + x
return _RoundSTE.apply(x)


def floor_ste(x: torch.Tensor):
Expand All @@ -193,7 +193,7 @@ def floor_ste(x: torch.Tensor):
Returns:
torch.Tensor
"""
return (x.floor() - x).detach() + x
return _FloorSTE.apply(x)


def ceil_ste(x: torch.Tensor):
Expand All @@ -205,7 +205,57 @@ def ceil_ste(x: torch.Tensor):
Returns:
torch.Tensor
"""
return (x.ceil() - x).detach() + x
return _CeilSTE.apply(x)


class _RoundSTE(torch.autograd.Function):
    """Autograd function: rounds in forward, passes the gradient through in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return ``x`` rounded element-wise."""
        rounded = x.round()
        return rounded

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``grad_output`` unchanged as the gradient for ``x``."""
        return grad_output
Comment on lines +211 to +218
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These STE helpers are now implemented via custom torch.autograd.Function, but there are no unit tests covering their backward semantics (e.g., verifying that round_ste/floor_ste/ceil_ste propagate gradients as identity and that the forward matches the corresponding rounding op). Adding small CPU tests for forward + torch.autograd.grad would help prevent silent regressions in quantization/tuning behavior.

Copilot uses AI. Check for mistakes.


class _FloorSTE(torch.autograd.Function):
    """Autograd function: floors in forward, passes the gradient through in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return the element-wise floor of ``x``."""
        floored = x.floor()
        return floored

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``grad_output`` unchanged as the gradient for ``x``."""
        return grad_output


class _CeilSTE(torch.autograd.Function):
    """Autograd function: ceils in forward, passes the gradient through in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Return the element-wise ceiling of ``x``."""
        ceiled = x.ceil()
        return ceiled

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``grad_output`` unchanged as the gradient for ``x``."""
        return grad_output


class _Float8CastSTE(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, fp8_dtype):
return x.to(fp8_dtype).to(x.dtype)

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
return grad_output, None


class _HpuFloat8CastSTE(torch.autograd.Function):
    """Autograd function: FP8 cast via the HPU op in forward, gradient passed through."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, fp8_dtype):
        """Call ``torch.ops.hpu.cast_to_fp8_v2`` on ``x`` and return element 0 in ``x.dtype``."""
        source_dtype = x.dtype
        cast_outputs = torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, fp8_dtype)
        # Only the first element of the op's result is used.
        return cast_outputs[0].to(source_dtype)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``grad_output`` for ``x`` and ``None`` for ``fp8_dtype``."""
        # One entry per forward input; the dtype argument receives no gradient.
        return grad_output, None


@torch._dynamo.disable()
Expand All @@ -221,9 +271,7 @@ def float8_e4m3fn_ste(x: torch.Tensor):
Returns:
torch.Tensor: Quantized and dequantized tensor using float8 format.
"""
fp8 = (x.to(torch.float8_e4m3fn).to(x.dtype) - x).detach() + x

return fp8
return _Float8CastSTE.apply(x, torch.float8_e4m3fn)


def float8_e5m2_ste(x: torch.Tensor):
Expand All @@ -238,9 +286,7 @@ def float8_e5m2_ste(x: torch.Tensor):
Returns:
torch.Tensor: Quantized and dequantized tensor using float8 format.
"""
fp8 = (x.to(torch.float8_e5m2).to(x.dtype) - x).detach() + x

return fp8
return _Float8CastSTE.apply(x, torch.float8_e5m2)


def float8_e4m3fn_hpu_ste(x: torch.Tensor):
Expand All @@ -255,9 +301,7 @@ def float8_e4m3fn_hpu_ste(x: torch.Tensor):
Returns:
torch.Tensor: Quantized and dequantized tensor using float8 format.
"""
fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x

return fp8
return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn)


def float8_e4m3fnuz_hpu_ste(x: torch.Tensor):
Expand All @@ -272,8 +316,7 @@ def float8_e4m3fnuz_hpu_ste(x: torch.Tensor):
Returns:
torch.Tensor: Quantized and dequantized tensor using float8 format.
"""
fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x
return fp8
return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn)
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float8_e4m3fnuz_hpu_ste is currently casting with torch.float8_e4m3fn, which conflicts with the function name and with the Gaudi2 FP8 path (which clips/scales using torch.float8_e4m3fnuz). This will change quantization behavior/range for the fnuz path. Use torch.float8_e4m3fnuz here to keep the dtype flavor consistent end-to-end.

Suggested change
return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn)
return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fnuz)

Copilot uses AI. Check for mistakes.


@lru_cache(None)
Expand Down
Loading