diff --git a/auto_round/data_type/nvfp.py b/auto_round/data_type/nvfp.py
index 60cee640a..f14b0a423 100644
--- a/auto_round/data_type/nvfp.py
+++ b/auto_round/data_type/nvfp.py
@@ -184,10 +184,19 @@ def cast_to_ue5m3(tensor):
     return res
 
 
-def cast_to_ue5m3_ste(x):
-    fp4 = (cast_to_ue5m3(x).to(x.dtype) - x).detach() + x
+class _UE5M3CastSTE(torch.autograd.Function):
+    """Straight-through estimator: forward applies cast_to_ue5m3, backward returns the gradient unchanged."""
+    @staticmethod
+    def forward(ctx, x: torch.Tensor):
+        return cast_to_ue5m3(x).to(x.dtype)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output
 
-    return fp4
+
+def cast_to_ue5m3_ste(x):
+    return _UE5M3CastSTE.apply(x)
 
 
 def ref_fp4_quant(x, global_scale, block_size=16, v=0, max_scale=1.0):
diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py
index bd4d74eaa..37414a809 100644
--- a/auto_round/data_type/utils.py
+++ b/auto_round/data_type/utils.py
@@ -181,7 +181,7 @@ def round_ste(x: torch.Tensor):
     Returns:
         torch.Tensor
     """
-    return (x.round() - x).detach() + x
+    return _RoundSTE.apply(x)
 
 
 def floor_ste(x: torch.Tensor):
@@ -193,7 +193,7 @@ def floor_ste(x: torch.Tensor):
     Returns:
         torch.Tensor
     """
-    return (x.floor() - x).detach() + x
+    return _FloorSTE.apply(x)
 
 
 def ceil_ste(x: torch.Tensor):
@@ -205,7 +205,62 @@ def ceil_ste(x: torch.Tensor):
     Returns:
         torch.Tensor
     """
-    return (x.ceil() - x).detach() + x
+    return _CeilSTE.apply(x)
+
+
+class _RoundSTE(torch.autograd.Function):
+    """Straight-through estimator: forward applies torch.round, backward returns the gradient unchanged."""
+    @staticmethod
+    def forward(ctx, x: torch.Tensor):
+        return torch.round(x)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output
+
+
+class _FloorSTE(torch.autograd.Function):
+    """Straight-through estimator: forward applies torch.floor, backward returns the gradient unchanged."""
+    @staticmethod
+    def forward(ctx, x: torch.Tensor):
+        return torch.floor(x)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output
+
+
+class _CeilSTE(torch.autograd.Function):
+    """Straight-through estimator: forward applies torch.ceil, backward returns the gradient unchanged."""
+    @staticmethod
+    def forward(ctx, x: torch.Tensor):
+        return torch.ceil(x)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output
+
+
+class _Float8CastSTE(torch.autograd.Function):
+    """Straight-through estimator: forward casts x to fp8_dtype and back to x.dtype, backward is identity."""
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, fp8_dtype):
+        return x.to(fp8_dtype).to(x.dtype)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output, None  # fp8_dtype is not a tensor, so it gets no gradient
+
+
+class _HpuFloat8CastSTE(torch.autograd.Function):
+    """Straight-through estimator: forward casts via torch.ops.hpu.cast_to_fp8_v2, backward is identity."""
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, fp8_dtype):
+        return torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, fp8_dtype)[0].to(x.dtype)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        return grad_output, None  # fp8_dtype is not a tensor, so it gets no gradient
 
 
 @torch._dynamo.disable()
@@ -221,9 +276,7 @@ def float8_e4m3fn_ste(x: torch.Tensor):
     Returns:
         torch.Tensor: Quantized and dequantized tensor using float8 format.
     """
-    fp8 = (x.to(torch.float8_e4m3fn).to(x.dtype) - x).detach() + x
-
-    return fp8
+    return _Float8CastSTE.apply(x, torch.float8_e4m3fn)
 
 
 def float8_e5m2_ste(x: torch.Tensor):
@@ -238,9 +291,7 @@ def float8_e5m2_ste(x: torch.Tensor):
     Returns:
         torch.Tensor: Quantized and dequantized tensor using float8 format.
     """
-    fp8 = (x.to(torch.float8_e5m2).to(x.dtype) - x).detach() + x
-
-    return fp8
+    return _Float8CastSTE.apply(x, torch.float8_e5m2)
 
 
 def float8_e4m3fn_hpu_ste(x: torch.Tensor):
@@ -255,9 +306,7 @@ def float8_e4m3fn_hpu_ste(x: torch.Tensor):
     Returns:
         torch.Tensor: Quantized and dequantized tensor using float8 format.
     """
-    fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x
-
-    return fp8
+    return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn)
 
 
 def float8_e4m3fnuz_hpu_ste(x: torch.Tensor):
@@ -272,8 +321,7 @@ def float8_e4m3fnuz_hpu_ste(x: torch.Tensor):
     Returns:
         torch.Tensor: Quantized and dequantized tensor using float8 format.
     """
-    fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x
-    return fp8
+    return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn)
 
 
 @lru_cache(None)