Commit 28e6aca
fix(utils): propagate non_blocking in TorchAOBaseTensor._to_copy and _get_to_kwargs (#4297)
* fix(utils): propagate non_blocking in TorchAOBaseTensor._to_copy and _get_to_kwargs
## Problem
`_get_to_kwargs` explicitly discarded the `non_blocking` argument parsed from
`torch._C._nn._parse_to`, with a comment saying it is "not very useful for
most tensor subclasses". As a result, any call to `tensor.to(device,
non_blocking=True)` on a `TorchAOBaseTensor` subclass silently became a
blocking transfer at the inner-tensor level.
This matters in practice for async CPU→GPU offloading workflows such as
`diffusers` `enable_group_offload(use_stream=True)`: the diffusers hook
schedules copies with `non_blocking=True` so that the transfer stream and
the compute stream can overlap. Because the flag was dropped, all copies
became blocking, negating the overlap benefit.
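The calling pattern at issue looks roughly like this (a sketch; `offload_copy` is a hypothetical stand-in for the diffusers hook, and the copy is only truly asynchronous for pinned-CPU→CUDA transfers):

```python
import torch

# Sketch of the group-offload calling pattern: the hook issues the copy with
# non_blocking=True so a transfer stream can overlap the compute stream.
def offload_copy(t: torch.Tensor, device: str) -> torch.Tensor:
    # This is the flag TorchAOBaseTensor subclasses were silently dropping.
    return t.to(device, non_blocking=True)

x = torch.randn(8, 8)
device = "cuda" if torch.cuda.is_available() else "cpu"
y = offload_copy(x, device)
```

With the flag dropped at the inner-tensor level, each such copy serialized against the host, so the transfer/compute overlap never materialized.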
On AMD ROCm (gfx1xxx) the missing non_blocking also interacts with a
separate stream-ordering race (fixed in huggingface/diffusers#13502): the
default stream can race ahead of "blocking" copies that the OS scheduler
hasn't committed yet, producing device-mismatch errors in the first matmul.
## Fix
1. `_get_to_kwargs`: include `non_blocking` in the returned kwargs dict.
2. `TorchAOBaseTensor._to_copy.default`: pop `non_blocking` from kwargs and
forward it to every inner `.to()` call for both `tensor_data_names` and
`optional_tensor_data_names`.
The change is backward-compatible: when `non_blocking=False` (the default),
behaviour is identical to before.
## Tested on
- 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7
- FLUX.1-dev int8 (`Int8WeightOnlyConfig`) with `enable_group_offload(use_stream=True)`
- Companion fix in diffusers: huggingface/diffusers#13502
* test(utils): add non_blocking propagation test for _get_to_kwargs
Verifies the contract change in TorchAOBaseTensor._get_to_kwargs:
the returned kwargs dict now includes `non_blocking`, propagated
from the original `.to(device, non_blocking=...)` call.
Covers three cases: explicit True, explicit False, and default
(unspecified). Runs on CPU only, no @skip_if_no_cuda needed.
Addresses review feedback on PR #4297.
2 files changed: 43 additions & 10 deletions