-
Notifications
You must be signed in to change notification settings - Fork 136
Fix QDQ inference OOM issue. #1763
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -181,7 +181,7 @@ def round_ste(x: torch.Tensor): | |||||
| Returns: | ||||||
| torch.Tensor | ||||||
| """ | ||||||
| return (x.round() - x).detach() + x | ||||||
| return _RoundSTE.apply(x) | ||||||
|
|
||||||
|
|
||||||
| def floor_ste(x: torch.Tensor): | ||||||
|
|
@@ -193,7 +193,7 @@ def floor_ste(x: torch.Tensor): | |||||
| Returns: | ||||||
| torch.Tensor | ||||||
| """ | ||||||
| return (x.floor() - x).detach() + x | ||||||
| return _FloorSTE.apply(x) | ||||||
|
|
||||||
|
|
||||||
| def ceil_ste(x: torch.Tensor): | ||||||
|
|
@@ -205,7 +205,57 @@ def ceil_ste(x: torch.Tensor): | |||||
| Returns: | ||||||
| torch.Tensor | ||||||
| """ | ||||||
| return (x.ceil() - x).detach() + x | ||||||
| return _CeilSTE.apply(x) | ||||||
|
|
||||||
|
|
||||||
| class _RoundSTE(torch.autograd.Function): | ||||||
| @staticmethod | ||||||
| def forward(ctx, x: torch.Tensor): | ||||||
| return torch.round(x) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def backward(ctx, grad_output: torch.Tensor): | ||||||
| return grad_output | ||||||
|
Comment on lines
+211
to
+218
|
||||||
|
|
||||||
|
|
||||||
| class _FloorSTE(torch.autograd.Function): | ||||||
| @staticmethod | ||||||
| def forward(ctx, x: torch.Tensor): | ||||||
| return torch.floor(x) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def backward(ctx, grad_output: torch.Tensor): | ||||||
| return grad_output | ||||||
|
|
||||||
|
|
||||||
| class _CeilSTE(torch.autograd.Function): | ||||||
| @staticmethod | ||||||
| def forward(ctx, x: torch.Tensor): | ||||||
| return torch.ceil(x) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def backward(ctx, grad_output: torch.Tensor): | ||||||
| return grad_output | ||||||
|
|
||||||
|
|
||||||
| class _Float8CastSTE(torch.autograd.Function): | ||||||
| @staticmethod | ||||||
| def forward(ctx, x: torch.Tensor, fp8_dtype): | ||||||
| return x.to(fp8_dtype).to(x.dtype) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def backward(ctx, grad_output: torch.Tensor): | ||||||
| return grad_output, None | ||||||
|
|
||||||
|
|
||||||
| class _HpuFloat8CastSTE(torch.autograd.Function): | ||||||
| @staticmethod | ||||||
| def forward(ctx, x: torch.Tensor, fp8_dtype): | ||||||
| return torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, fp8_dtype)[0].to(x.dtype) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def backward(ctx, grad_output: torch.Tensor): | ||||||
| return grad_output, None | ||||||
|
|
||||||
|
|
||||||
| @torch._dynamo.disable() | ||||||
|
|
@@ -221,9 +271,7 @@ def float8_e4m3fn_ste(x: torch.Tensor): | |||||
| Returns: | ||||||
| torch.Tensor: Quantized and dequantized tensor using float8 format. | ||||||
| """ | ||||||
| fp8 = (x.to(torch.float8_e4m3fn).to(x.dtype) - x).detach() + x | ||||||
|
|
||||||
| return fp8 | ||||||
| return _Float8CastSTE.apply(x, torch.float8_e4m3fn) | ||||||
|
|
||||||
|
|
||||||
| def float8_e5m2_ste(x: torch.Tensor): | ||||||
|
|
@@ -238,9 +286,7 @@ def float8_e5m2_ste(x: torch.Tensor): | |||||
| Returns: | ||||||
| torch.Tensor: Quantized and dequantized tensor using float8 format. | ||||||
| """ | ||||||
| fp8 = (x.to(torch.float8_e5m2).to(x.dtype) - x).detach() + x | ||||||
|
|
||||||
| return fp8 | ||||||
| return _Float8CastSTE.apply(x, torch.float8_e5m2) | ||||||
|
|
||||||
|
|
||||||
| def float8_e4m3fn_hpu_ste(x: torch.Tensor): | ||||||
|
|
@@ -255,9 +301,7 @@ def float8_e4m3fn_hpu_ste(x: torch.Tensor): | |||||
| Returns: | ||||||
| torch.Tensor: Quantized and dequantized tensor using float8 format. | ||||||
| """ | ||||||
| fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x | ||||||
|
|
||||||
| return fp8 | ||||||
| return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn) | ||||||
|
|
||||||
|
|
||||||
| def float8_e4m3fnuz_hpu_ste(x: torch.Tensor): | ||||||
|
|
@@ -272,8 +316,7 @@ def float8_e4m3fnuz_hpu_ste(x: torch.Tensor): | |||||
| Returns: | ||||||
| torch.Tensor: Quantized and dequantized tensor using float8 format. | ||||||
| """ | ||||||
| fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x | ||||||
| return fp8 | ||||||
| return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn) | ||||||
|
||||||
| return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn) | |
| return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fnuz) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cast_to_ue5m3_stenow relies on a customtorch.autograd.Function, but there are no tests asserting the straight-through gradient behavior for this path. Consider adding a unit test that checks forward equivalence tocast_to_ue5m3and that gradients throughcast_to_ue5m3_steare identity (within expected dtype tolerances).