Skip to content

Commit e68c2f6

Browse files
authored
[Feature] Per-leaf device/dtype casting via an attrs tensordict (#1678)
1 parent 69cd230 commit e68c2f6

8 files changed

Lines changed: 456 additions & 0 deletions

File tree

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@
170170
"python": ("https://docs.python.org/3/", None),
171171
"torch": ("https://pytorch.org/docs/stable/", None),
172172
"numpy": ("https://numpy.org/doc/stable/", None),
173+
"pytorch_tutorials": ("https://docs.pytorch.org/tutorials/", None),
173174
}
174175

175176

docs/source/overview.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,42 @@ Suppose we have some function foo() -> TensorDict and that we do something like
445445
When ``i == 0`` the empty :class:`~tensordict.TensorDict` will automatically be populated with empty tensors with batch
446446
size N. In subsequent iterations of the loop the updates will all be written in-place.
447447

448+
Per-leaf device and dtype casting
449+
---------------------------------
450+
451+
A deviceless :class:`~tensordict.TensorDict` (``device=None``) may legitimately hold leaves on
452+
different devices or with different dtypes. The usual ``td.to(device)`` collapses every leaf onto a
453+
single device, which is not always what you want — for instance when aligning an arbitrary tensordict
454+
to the heterogeneous placement of another one (FSDP/TP shards, cross-rank exchanges, mixed-precision
455+
state dicts).
456+
457+
Calling :meth:`~tensordict.TensorDictBase.attrs` returns a lightweight tensordict whose leaves are
458+
:class:`~tensordict.TensorAttrs` — one per leaf — recording each source leaf's target device, dtype
459+
and shape. That tensordict can be passed directly to ``to()``, which then casts each leaf of the
460+
caller to the attributes of the matching leaf in the spec:
461+
462+
>>> import torch
>>> from tensordict import TensorDict
>>> td0 = TensorDict({"a": torch.zeros(3, device="cuda:0"),
...                   "b": torch.zeros(3, device="cpu")}, batch_size=[3])
>>> td1 = TensorDict({"a": torch.zeros(3, device="cpu"),
...                   "b": torch.zeros(3, device="cuda:1")}, batch_size=[3])
>>> td2 = td0.to(td1.attrs())  # td2["a"] is on cpu, td2["b"] is on cuda:1
468+
469+
Copies are issued asynchronously by default. A synchronization is performed at the end only when at
470+
least one leaf moved D2H (device to host) — the case where the host read is not coordinated by the
471+
source device's stream scheduler. H2D and cross-device D2D transfers do not need an explicit sync:
472+
subsequent kernels on the destination device already serialize on the stream that enqueued the copy,
473+
so CUDA (and other accelerator runtimes) handle the dependency for you. Pass ``non_blocking=True``
474+
to skip the sync entirely (caller takes responsibility), or ``non_blocking=False`` to make copies
475+
blocking. Dtype-only specs never trigger a sync. Leaves that do not appear in the spec tensordict
476+
are passed through unchanged, so a partial spec can be used to retarget a subset of leaves.
477+
:meth:`~tensordict.TensorDictBase.attrs` also accepts a ``fields`` argument
478+
(``("device", "dtype", "shape")`` by default) to limit what each :class:`~tensordict.TensorAttrs`
479+
records, which is useful when the caller only needs to match one dimension of the spec.
480+
481+
For background on the non-blocking / pinned-memory rules that motivate this D2H-only sync, see the
482+
PyTorch tutorial :external+pytorch_tutorials:doc:`A Guide on Good Usage of non_blocking and pin_memory() <intermediate/pinmem_nonblock>`.
483+
448484
TensorDictModule
449485
----------------
450486

docs/source/reference/tc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ Here is an example:
283283
NonTensorData
284284
MetaData
285285
NonTensorStack
286+
TensorAttrs
286287
UnbatchedTensor
287288
from_dataclass
288289

tensordict/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
NonTensorData,
5454
NonTensorDataBase,
5555
NonTensorStack,
56+
TensorAttrs,
5657
tensorclass,
5758
TensorClass,
5859
)
@@ -169,6 +170,7 @@
169170
"NonTensorData",
170171
"NonTensorDataBase",
171172
"NonTensorStack",
173+
"TensorAttrs",
172174
# NN imports
173175
"as_tensordict_module",
174176
"TensorClassModuleBase",

tensordict/base.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16178,6 +16178,21 @@ def to(self, *args, **kwargs) -> Self:
1617816178
>>> assert td["x"].device.type == "cpu"
1617916179
>>> assert td["y"].device.type == "cpu" # Output also restored to original device
1618016180
"""
16181+
# Per-leaf spec: a positional tensordict argument is interpreted as an
16182+
# attrs tensordict whose leaves are `TensorAttrs` — each source leaf
16183+
# is cast to its counterpart's device/dtype.
16184+
if (
16185+
args
16186+
and not isinstance(args[0], Tensor)
16187+
and _is_tensor_collection(type(args[0]))
16188+
):
16189+
attrs_td = args[0]
16190+
if len(args) > 1:
16191+
raise TypeError(
16192+
"to(attrs_td) does not accept additional positional arguments."
16193+
)
16194+
return self._to_per_leaf(attrs_td, **kwargs)
16195+
1618116196
non_blocking = kwargs.pop("non_blocking", None)
1618216197

1618316198
(
@@ -16277,6 +16292,148 @@ def get(name, val):
1627716292
self._sync_all()
1627816293
return result
1627916294

16295+
def attrs(
    self,
    *,
    fields: Sequence[str] = ("device", "dtype", "shape"),
    num_threads: int | None = None,
) -> Self:
    """Build a deviceless tensordict of :class:`~tensordict.TensorAttrs` leaves.

    Every tensor leaf of ``self`` is mapped to a :class:`~tensordict.TensorAttrs`
    carrying the requested attributes of that leaf. The resulting tensordict is
    meant to be handed to :meth:`to` so that each leaf of another (possibly
    heterogeneous) tensordict can be cast to its counterpart's device/dtype.

    Keyword Args:
        fields (sequence of str, optional): attributes recorded on each
            :class:`~tensordict.TensorAttrs`; any subset of
            ``("device", "dtype", "shape")``. Unlisted attributes stay ``None``.
            Defaults to ``("device", "dtype", "shape")``.
        num_threads (int or None, optional): thread count used while walking
            the leaves. ``None`` (the default) means single-threaded. Building
            :class:`TensorAttrs` objects is Python-bound, so threads rarely
            help; the argument exists for symmetry with :meth:`to`.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> td = TensorDict(
        ...     {"a": torch.zeros(3, device="cpu"),
        ...      "b": torch.zeros(3, dtype=torch.int32)},
        ...     batch_size=[3],
        ... )
        >>> attrs = td.attrs(fields=("device", "dtype"))
        >>> target = TensorDict(a=torch.zeros(3, device="cpu"), b=torch.zeros(3), batch_size=[3])
        >>> out = target.to(attrs)  # casts `b` to int32 per-leaf
        >>> out["b"].dtype
        torch.int32
    """
    from tensordict.tensorclass import TensorAttrs

    # Freeze the requested field set once; each leaf reuses it.
    requested = tuple(fields)

    def _leaf_to_attrs(leaf):
        return TensorAttrs.from_tensor(leaf, fields=requested)

    return self._fast_apply(
        _leaf_to_attrs,
        batch_size=(),
        device=None,
        propagate_lock=False,
        is_leaf=_NESTED_TENSORS_AS_LISTS,
        num_threads=0 if num_threads is None else num_threads,
    )
16345+
16346+
def _to_per_leaf(
    self,
    attrs_td: TensorDictBase,
    *,
    non_blocking: bool | None = None,
    non_blocking_pin: bool = False,
    num_threads: int | None = None,
    inplace: bool = False,
) -> Self:
    """Cast every leaf of ``self`` to the attributes stored in ``attrs_td``.

    A leaf that has no counterpart in ``attrs_td`` — or whose counterpart has
    both ``tgt_device=None`` and ``tgt_dtype=None`` — is returned untouched.

    Unless the caller passes ``non_blocking`` explicitly, copies are enqueued
    asynchronously and a single :meth:`_sync_all` runs afterwards, but only
    when some leaf was copied D2H (accelerator source, CPU target). H2D and
    D2D transfers need no explicit barrier: later kernels on the destination
    device serialize on the stream that enqueued the copy. Host reads are the
    one direction outside the device's stream scheduler, hence the D2H-only
    sync.
    """
    from tensordict.tensorclass import TensorAttrs

    if non_blocking_pin:
        raise NotImplementedError(
            "non_blocking_pin is not yet supported when an attrs tensordict is passed to `to()`."
        )

    # Default policy: async copies followed by a conditional sync. An explicit
    # non_blocking value puts the caller in charge (True: no sync; False:
    # blocking copies, no sync needed).
    if non_blocking is None:
        async_copy, sync_at_end = True, True
    else:
        async_copy, sync_at_end = non_blocking, not non_blocking

    def _leaf_or_attrs(cls):
        return issubclass(cls, TensorAttrs) or _default_is_leaf(cls)

    # Collect only the TensorAttrs leaves of the spec tensordict.
    spec = {
        key: leaf
        for key, leaf in attrs_td.items(
            include_nested=True, leaves_only=True, is_leaf=_leaf_or_attrs
        )
        if isinstance(leaf, TensorAttrs)
    }

    # D2H is the sole direction requiring a post-copy sync: host memory sits
    # outside the source device's stream scheduler, so a host read right after
    # the call could observe stale data. Track it via a one-slot mutable cell
    # shared with the closure below.
    d2h_seen = [False]

    def _cast(name, tensor):
        leaf_attrs = spec.get(name)
        if leaf_attrs is None:
            return tensor
        tgt_device = leaf_attrs.tgt_device
        tgt_dtype = leaf_attrs.tgt_dtype
        if tgt_device is None and tgt_dtype is None:
            return tensor
        if (
            tgt_device is not None
            and tgt_device.type == "cpu"
            and tensor.device.type != "cpu"
        ):
            d2h_seen[0] = True
        return tensor.to(
            device=tgt_device,
            dtype=tgt_dtype,
            non_blocking=async_copy,
        )

    result = self._fast_apply(
        _cast,
        named=True,
        nested_keys=True,
        is_leaf=_NESTED_TENSORS_AS_LISTS,
        propagate_lock=True,
        out=self if inplace else None,
        checked=True,
        device=None,
        num_threads=0 if num_threads is None else num_threads,
    )

    if sync_at_end and d2h_seen[0]:
        self._sync_all()

    return result
16436+
1628016437
def _to_consolidated(
1628116438
self, *, device, pin_memory, num_threads, non_blocking, inplace
1628216439
):

tensordict/tensorclass.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def __subclasscheck__(self, subclass):
353353
"atleast_1d",
354354
"atleast_2d",
355355
"atleast_3d",
356+
"attrs",
356357
"auto_batch_size_",
357358
"auto_device_",
358359
"bfloat16",
@@ -4600,6 +4601,54 @@ def _stack_non_tensor(
46004601
return NonTensorStack(*list_of_non_tensor, stack_dim=dim)
46014602

46024603

4604+
class TensorAttrs(TensorClass):
    """Per-leaf descriptor of a tensor's attributes: device, dtype, and shape.

    Follows the `PyTorch Tensor Attributes <https://pytorch.org/docs/stable/tensor_attributes.html>`_
    vocabulary. One instance is produced per leaf by
    :meth:`~tensordict.TensorDictBase.attrs`, and
    :meth:`~tensordict.TensorDictBase.to` consumes a tensordict of them to
    cast each leaf individually — useful for heterogeneous tensordicts
    (``device=None`` with leaves spread across devices).

    Field names carry a ``tgt_`` prefix so they don't shadow the tensordict
    attributes ``device``/``dtype``/``shape``.

    Which fields :meth:`from_tensor` fills in is governed by its ``fields``
    argument, keeping the API open to future attributes.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> td = TensorDict({"a": torch.zeros(3, device="cpu"),
        ...                  "b": torch.zeros(3, dtype=torch.int32)}, batch_size=[3])
        >>> attrs = td.attrs()
        >>> attrs["a"].tgt_device, attrs["a"].tgt_dtype
        (device(type='cpu'), torch.float32)

    """

    # Target device/dtype/shape of the described leaf; None means "not recorded".
    tgt_device: Any = None
    tgt_dtype: Any = None
    tgt_shape: Any = None

    @classmethod
    def from_tensor(cls, tensor, *, fields=("device", "dtype", "shape")):
        """Construct a :class:`TensorAttrs` from ``tensor``, filling only ``fields``.

        ``fields`` takes the short names ``"device"``, ``"dtype"``, ``"shape"``;
        anything not requested is left ``None``.
        """
        # Map each short field name to how its value is read off the tensor.
        extractors = (
            ("device", lambda t: t.device),
            ("dtype", lambda t: t.dtype),
            ("shape", lambda t: torch.Size(t.shape)),
        )
        kwargs = {
            f"tgt_{name}": extract(tensor)
            for name, extract in extractors
            if name in fields
        }
        return cls(batch_size=(), **kwargs)
4650+
4651+
46034652
# For __setitem__ and _update_at_ we don't pass a kwarg but use a global variable instead
46044653
_BREAK_ON_MEMMAP = True
46054654

tensordict/tensorclass.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,9 @@ class TensorClass:
15481548
@overload
15491549
def to(self, *, batch_size: torch.Size) -> Self: ...
15501550
def to(self, *args, **kwargs) -> Self: ...
1551+
def attrs(
1552+
self, *, fields: Sequence[str] = ("device", "dtype", "shape")
1553+
) -> Self: ...
15511554
def is_floating_point(self) -> bool: ...
15521555
def double(self) -> Self: ...
15531556
def float(self) -> Self: ...

0 commit comments

Comments
 (0)