
Commit 28e6aca

fix(utils): propagate non_blocking in TorchAOBaseTensor._to_copy and _get_to_kwargs (#4297)
* fix(utils): propagate non_blocking in TorchAOBaseTensor._to_copy and _get_to_kwargs

## Problem

`_get_to_kwargs` explicitly discarded the `non_blocking` argument parsed from `torch._C._nn._parse_to`, with a comment saying it is "not very useful for most tensor subclasses". As a result, any call to `tensor.to(device, non_blocking=True)` on a `TorchAOBaseTensor` subclass silently became a blocking transfer at the inner-tensor level.

This matters in practice for async CPU→GPU offloading workflows such as diffusers' `enable_group_offload(use_stream=True)`: the diffusers hook schedules copies with `non_blocking=True` so that the transfer stream and the compute stream can overlap. Because the flag was dropped, all copies became blocking, negating the overlap benefit.

On AMD ROCm (gfx1xxx) the missing non_blocking also interacts with a separate stream-ordering race (fixed in huggingface/diffusers#13502): the default stream can race ahead of "blocking" copies that the OS scheduler hasn't committed yet, producing device-mismatch errors in the first matmul.

## Fix

1. `_get_to_kwargs`: include `non_blocking` in the returned kwargs dict.
2. `TorchAOBaseTensor._to_copy.default`: pop `non_blocking` from kwargs and forward it to every inner `.to()` call, for both `tensor_data_names` and `optional_tensor_data_names`.

The change is backward-compatible: when `non_blocking=False` (the default), behaviour is identical to before.

## Tested on

- 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7
- FLUX.1-dev int8 (`Int8WeightOnlyConfig`) with `enable_group_offload(use_stream=True)`
- Companion fix in diffusers: huggingface/diffusers#13502

* test(utils): add non_blocking propagation test for _get_to_kwargs

Verifies the contract change in `TorchAOBaseTensor._get_to_kwargs`: the returned kwargs dict now includes `non_blocking`, propagated from the original `.to(device, non_blocking=...)` call. Covers three cases: explicit True, explicit False, and default (unspecified). Runs on CPU only, so no `@skip_if_no_cuda` is needed.

Addresses review feedback on PR #4297.
1 parent 9052ece commit 28e6aca
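
For context, the overlap pattern the commit message refers to looks roughly like the sketch below. It is illustrative, not code from this PR or from diffusers: the tensor and stream names are made up, and a `non_blocking` host-to-device copy only actually overlaps with compute when the CPU source is in pinned memory.

```python
import torch

# Illustrative sketch of stream-overlapped offloading (not from this PR).
# A non_blocking H2D copy overlaps with compute only if the source is pinned.
cpu_weight = torch.randn(4096, 4096).pin_memory()

transfer_stream = torch.cuda.Stream()
with torch.cuda.stream(transfer_stream):
    # Before this fix, a TorchAOBaseTensor subclass would silently drop
    # non_blocking=True for its inner tensors, making this copy blocking.
    gpu_weight = cpu_weight.to("cuda", non_blocking=True)

# Compute on the default stream may run while the copy is in flight;
# the consumer must wait on the transfer stream before first use.
torch.cuda.current_stream().wait_stream(transfer_stream)
```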

2 files changed: 43 additions & 10 deletions


test/test_utils.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -56,6 +56,39 @@ def __init__(self, data):
         with self.assertRaisesRegex(NotImplementedError, "arg_types"):
             l.weight = torch.nn.Parameter(MyTensor(l.weight))
 
+    def test_get_to_kwargs_non_blocking(self):
+        """Verify _get_to_kwargs parses and returns the non_blocking flag."""
+
+        class MyTensor(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+
+            def __new__(cls, qdata, attr="attr", device=None):
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device, "dtype": qdata.dtype}
+                r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
+                r.qdata = qdata
+                r.attr = attr
+                return r
+
+            def __init__(self, qdata, attr="attr", device=None):
+                pass
+
+        t = MyTensor(torch.randn(4, 4))
+
+        # non_blocking=True is preserved
+        kwargs = t._get_to_kwargs(device="cpu", non_blocking=True)
+        self.assertTrue(kwargs["non_blocking"])
+
+        # non_blocking=False (explicit) is preserved
+        kwargs = t._get_to_kwargs(device="cpu", non_blocking=False)
+        self.assertFalse(kwargs["non_blocking"])
+
+        # default (not specified) → False
+        kwargs = t._get_to_kwargs(device="cpu")
+        self.assertFalse(kwargs["non_blocking"])
+
     def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
         # get `all_tensor_data_names` and `all_tensor_attribute_names`
         all_tensor_data_names = lp_tensor.tensor_data_names.copy()
```
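
As a complement to the CPU-only assertions above, a hypothetical end-to-end sketch (assuming a CUDA/ROCm device and the `MyTensor` subclass defined in the test) of the behaviour the fix enables:

```python
# Hypothetical usage, not part of the test: requires a CUDA/ROCm device.
# With the fix, the inner `qdata` tensor is copied with non_blocking=True
# instead of silently falling back to a blocking transfer.
t = MyTensor(torch.randn(4, 4).pin_memory())
t_gpu = t.to("cuda", non_blocking=True)
```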

torchao/utils.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -594,15 +594,19 @@ def _(func, types, args, kwargs):
     ):
         kwargs = self._get_to_kwargs(*args[1:], **kwargs)
         device = kwargs.pop("device")
+        non_blocking = kwargs.pop("non_blocking", False)
         tensors = [
-            getattr(self, name).to(device) for name in self.tensor_data_names
+            getattr(self, name).to(device, non_blocking=non_blocking)
+            for name in self.tensor_data_names
         ]
         optional_tensors = []
         if hasattr(self, "optional_tensor_data_names"):
             for tensor_data_name in self.optional_tensor_data_names:
                 maybe_tensor = getattr(self, tensor_data_name)
                 if maybe_tensor is not None:
-                    optional_tensors.append(maybe_tensor.to(device))
+                    optional_tensors.append(
+                        maybe_tensor.to(device, non_blocking=non_blocking)
+                    )
                 else:
                     optional_tensors.append(None)
 
@@ -693,25 +697,21 @@ class MyTensor(torch.Tensor):
 
 
     def _get_to_kwargs(self, *args, **kwargs):
-        """Helper function to get the device and dtype keyword args for `aten._to_copy.default` op
-        only device and dtype are kept
+        """Helper function to get the device, dtype and non_blocking keyword args for `aten._to_copy.default` op
 
-        Returns: {"device": device, "dtype": dtype}
+        Returns: {"device": device, "dtype": dtype, "non_blocking": non_blocking}
         """
         # `torch._C._nn._parse_to` can't handle `layout` argument
         args = tuple(arg for arg in args if not isinstance(arg, torch.layout))
         if "layout" in kwargs:
             kwargs.pop("layout")
-        # ignoring `non_blocking` and `memory_format` args since these are not
-        # very useful for most of the tensor subclasses
-        # if in the future there are use cases that need these, we'd recommend
-        # to override `_get_to_kwargs` and return these args
-        device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
+        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
         device = self.device if device is None else device
         dtype = self.dtype if dtype is None else dtype
         kwargs = {
             "device": device,
             "dtype": dtype,
+            "non_blocking": non_blocking,
         }
         return kwargs
```
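
For reference, `torch._C._nn._parse_to` mirrors `Tensor.to`'s argument parsing and returns a `(device, dtype, non_blocking, memory_format)` tuple, which is why the fourth element stays discarded above. A quick check (private API, so subject to change between PyTorch releases):

```python
import torch

# Private API, subject to change: mirrors Tensor.to's argument parsing.
device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(
    "cpu", non_blocking=True
)
assert device == torch.device("cpu")
assert dtype is None and memory_format is None
assert non_blocking is True
```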
