1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -170,6 +170,7 @@
    "python": ("https://docs.python.org/3/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "pytorch_tutorials": ("https://docs.pytorch.org/tutorials/", None),
}


36 changes: 36 additions & 0 deletions docs/source/overview.rst
@@ -445,6 +445,42 @@ Suppose we have some function foo() -> TensorDict and that we do something like
When ``i == 0`` the empty :class:`~tensordict.TensorDict` will automatically be populated with empty tensors with batch
size N. In subsequent iterations of the loop the updates will all be written in-place.
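
A minimal sketch of the loop pattern referred to above (``foo`` and ``N`` are placeholders taken
from the surrounding prose):

>>> out = TensorDict({}, batch_size=[N])
>>> for i in range(N):
...     out[i] = foo()  # i == 0 allocates empty storage; later iterations write in-place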

Per-leaf device and dtype casting
---------------------------------

A deviceless :class:`~tensordict.TensorDict` (``device=None``) may legitimately hold leaves on
different devices or with different dtypes. The usual ``td.to(device)`` collapses every leaf onto a
single device, which is not always what you want — for instance when aligning an arbitrary tensordict
to the heterogeneous placement of another one (FSDP/TP shards, cross-rank exchanges, mixed-precision
state dicts).

Calling :meth:`~tensordict.TensorDictBase.attrs` returns a lightweight tensordict whose leaves are
:class:`~tensordict.TensorAttrs` instances, one per source leaf, each recording that leaf's device,
dtype and shape. The result can be passed directly to ``to()``, which then casts each leaf of the
caller to the attributes of the matching leaf in the spec:

>>> import torch
>>> from tensordict import TensorDict
>>> td0 = TensorDict({"a": torch.zeros(3, device="cuda:0"),
... "b": torch.zeros(3, device="cpu")}, batch_size=[3])
>>> td2 = TensorDict({"a": torch.zeros(3, device="cpu"),
... "b": torch.zeros(3, device="cuda:1")}, batch_size=[3])
>>> td3 = td0.to(td2.attrs()) # td3["a"] is on cpu, td3["b"] is on cuda:1

Copies are issued asynchronously by default. A synchronization is performed at the end only when at
least one leaf moved D2H (device to host) — the case where the host read is not coordinated by the
source device's stream scheduler. H2D and cross-device D2D transfers do not need an explicit sync:
subsequent kernels on the destination device already serialize on the stream that enqueued the copy,
so CUDA (and other accelerator runtimes) handle the dependency for you. Pass ``non_blocking=True``
to skip the sync entirely (caller takes responsibility), or ``non_blocking=False`` to make copies
blocking. Dtype-only specs never trigger a sync. Leaves that do not appear in the spec tensordict
are passed through unchanged, so a partial spec can be used to retarget a subset of leaves.
:meth:`~tensordict.TensorDictBase.attrs` also accepts a ``fields`` argument
(``("device", "dtype", "shape")`` by default) to limit what each :class:`~tensordict.TensorAttrs`
records, which is useful when the caller only needs to match one aspect of the spec.
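
As a sketch of the partial-spec pattern (the names below are ours, not part of the library):

>>> td = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
>>> dtype_spec = td.attrs(fields=("dtype",))  # records dtypes only; devices stay None
>>> out = td.half().to(dtype_spec)  # leaves are cast back to float32; devices are untouched
>>> out["a"].dtype
torch.float32

Since this spec carries no devices, the cast is dtype-only and never triggers a synchronization.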

For background on the non-blocking / pinned-memory rules that motivate this D2H-only sync, see the
PyTorch tutorial :external+pytorch_tutorials:doc:`A Guide on Good Usage of non_blocking and pin_memory() <intermediate/pinmem_nonblock>`.

TensorDictModule
----------------

1 change: 1 addition & 0 deletions docs/source/reference/tc.rst
@@ -283,6 +283,7 @@ Here is an example:
    NonTensorData
    MetaData
    NonTensorStack
    TensorAttrs
    UnbatchedTensor
    from_dataclass
2 changes: 2 additions & 0 deletions tensordict/__init__.py
@@ -53,6 +53,7 @@
    NonTensorData,
    NonTensorDataBase,
    NonTensorStack,
    TensorAttrs,
    tensorclass,
    TensorClass,
)
@@ -169,6 +170,7 @@
    "NonTensorData",
    "NonTensorDataBase",
    "NonTensorStack",
    "TensorAttrs",
    # NN imports
    "as_tensordict_module",
    "TensorClassModuleBase",
157 changes: 157 additions & 0 deletions tensordict/base.py
@@ -16118,6 +16118,21 @@ def to(self, *args, **kwargs) -> Self:
        >>> assert td["x"].device.type == "cpu"
        >>> assert td["y"].device.type == "cpu"  # Output also restored to original device
        """
        # Per-leaf spec: a positional tensordict argument is interpreted as an
        # attrs tensordict whose leaves are `TensorAttrs` — each source leaf
        # is cast to its counterpart's device/dtype.
        if (
            args
            and not isinstance(args[0], Tensor)
            and _is_tensor_collection(type(args[0]))
        ):
            attrs_td = args[0]
            if len(args) > 1:
                raise TypeError(
                    "to(attrs_td) does not accept additional positional arguments."
                )
            return self._to_per_leaf(attrs_td, **kwargs)

        non_blocking = kwargs.pop("non_blocking", None)

        (
@@ -16217,6 +16232,148 @@ def get(name, val):
            self._sync_all()
        return result

    def attrs(
        self,
        *,
        fields: Sequence[str] = ("device", "dtype", "shape"),
        num_threads: int | None = None,
    ) -> Self:
        """Return a deviceless tensordict whose leaves are :class:`~tensordict.TensorAttrs`.

        Each tensor leaf of ``self`` is replaced by a :class:`~tensordict.TensorAttrs`
        describing the requested tensor attributes. The result is intended to be passed to
        :meth:`to` to drive per-leaf device/dtype casting when the source tensordict
        is heterogeneous.

        Keyword Args:
            fields (sequence of str, optional): which attributes to record on each
                :class:`~tensordict.TensorAttrs`. Accepts any subset of
                ``("device", "dtype", "shape")``. Attributes not listed remain ``None``.
                Defaults to ``("device", "dtype", "shape")``.
            num_threads (int or None, optional): number of threads to use when
                iterating over the leaves. Defaults to ``None`` (single-threaded).
                Construction of :class:`TensorAttrs` is pure-Python work, so threading
                typically yields little benefit; the argument is exposed for symmetry
                with :meth:`to`.

        Examples:
            >>> import torch
            >>> from tensordict import TensorDict
            >>> td = TensorDict(
            ...     {"a": torch.zeros(3, device="cpu"),
            ...      "b": torch.zeros(3, dtype=torch.int32)},
            ...     batch_size=[3],
            ... )
            >>> attrs = td.attrs(fields=("device", "dtype"))
            >>> target = TensorDict(a=torch.zeros(3, device="cpu"), b=torch.zeros(3), batch_size=[3])
            >>> out = target.to(attrs)  # casts `b` to int32 per-leaf
            >>> out["b"].dtype
            torch.int32
        """
        from tensordict.tensorclass import TensorAttrs

        def _to_attrs(t):
            return TensorAttrs.from_tensor(t, fields=fields)

        return self._fast_apply(
            _to_attrs,
            batch_size=(),
            device=None,
            propagate_lock=False,
            is_leaf=_NESTED_TENSORS_AS_LISTS,
            num_threads=num_threads if num_threads is not None else 0,
        )

    def _to_per_leaf(
        self,
        attrs_td: TensorDictBase,
        *,
        non_blocking: bool | None = None,
        non_blocking_pin: bool = False,
        num_threads: int | None = None,
        inplace: bool = False,
    ) -> Self:
        """Cast each leaf of ``self`` to the attributes recorded in ``attrs_td``.

        Leaves absent from ``attrs_td`` (or whose attrs have both ``tgt_device=None``
        and ``tgt_dtype=None``) are passed through unchanged.

        When the caller does not pass ``non_blocking`` explicitly, per-leaf copies
        are issued asynchronously. A single :meth:`_sync_all` is invoked at the end
        only if at least one leaf went D2H (source on a CUDA/XPU/etc. device, target
        on CPU). Async H2D copies do not need an explicit sync — subsequent kernels
        on the destination device serialize on the same CUDA stream that queued the
        copy, so the dependency is already honored. Only the D2H direction needs the
        barrier because host reads are not coordinated by the device's stream scheduler.
        """
        from tensordict.tensorclass import TensorAttrs

        if non_blocking_pin:
            raise NotImplementedError(
                "non_blocking_pin is not yet supported when an attrs tensordict is passed to `to()`."
            )

        if non_blocking is None:
            sub_non_blocking = True
            do_sync = True
        else:
            sub_non_blocking = non_blocking
            do_sync = not non_blocking

        def _is_attrs_leaf(cls):
            return issubclass(cls, TensorAttrs) or _default_is_leaf(cls)

        spec: dict = {}
        for key, val in attrs_td.items(
            include_nested=True, leaves_only=True, is_leaf=_is_attrs_leaf
        ):
            if isinstance(val, TensorAttrs):
                spec[key] = val

        # D2H (device-to-host) is the only transfer direction that needs an explicit
        # sync after an async copy: host memory is outside the source device's stream
        # scheduler, so reads after the copy call returns may observe stale data. H2D
        # and cross-device D2D are fine — the destination's stream already serializes
        # on the enqueued copy.
        needs_d2h_sync = False

        def _cast(name, tensor):
            attrs = spec.get(name)
            if attrs is None:
                return tensor
            target_device = attrs.tgt_device
            target_dtype = attrs.tgt_dtype
            if target_device is None and target_dtype is None:
                return tensor
            if (
                target_device is not None
                and target_device.type == "cpu"
                and tensor.device.type != "cpu"
            ):
                nonlocal needs_d2h_sync
                needs_d2h_sync = True
            return tensor.to(
                device=target_device,
                dtype=target_dtype,
                non_blocking=sub_non_blocking,
            )

        result = self._fast_apply(
            _cast,
            named=True,
            nested_keys=True,
            is_leaf=_NESTED_TENSORS_AS_LISTS,
            propagate_lock=True,
            out=self if inplace else None,
            checked=True,
            device=None,
            num_threads=num_threads if num_threads is not None else 0,
        )

        if needs_d2h_sync and do_sync:
            self._sync_all()

        return result

    def _to_consolidated(
        self, *, device, pin_memory, num_threads, non_blocking, inplace
    ):
49 changes: 49 additions & 0 deletions tensordict/tensorclass.py
@@ -353,6 +353,7 @@ def __subclasscheck__(self, subclass):
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "attrs",
    "auto_batch_size_",
    "auto_device_",
    "bfloat16",
@@ -4543,6 +4544,54 @@ def _stack_non_tensor(
        return NonTensorStack(*list_of_non_tensor, stack_dim=dim)


class TensorAttrs(TensorClass):
    """Per-leaf carrier describing a tensor's attributes — device, dtype, and shape.

    Mirrors the `PyTorch Tensor Attributes <https://pytorch.org/docs/stable/tensor_attributes.html>`_
    terminology. Instances are produced by :meth:`~tensordict.TensorDictBase.attrs`
    — one per leaf of the source tensordict — and consumed by
    :meth:`~tensordict.TensorDictBase.to` to drive per-leaf device/dtype casting
    when the source tensordict is heterogeneous (e.g. ``device=None`` with leaves
    on different devices).

    The field names are prefixed with ``tgt_`` to avoid shadowing the tensordict
    attributes ``device``/``dtype``/``shape``.

    Which fields :meth:`from_tensor` populates is controlled by its ``fields``
    argument, so the API can grow as new tensor attributes are added.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> td = TensorDict({"a": torch.zeros(3, device="cpu"),
        ...                  "b": torch.zeros(3, dtype=torch.int32)}, batch_size=[3])
        >>> attrs = td.attrs()
        >>> attrs["a"].tgt_device, attrs["a"].tgt_dtype
        (device(type='cpu'), torch.float32)

    """

    tgt_device: Any = None
    tgt_dtype: Any = None
    tgt_shape: Any = None

    @classmethod
    def from_tensor(cls, tensor, *, fields=("device", "dtype", "shape")):
        """Build a :class:`TensorAttrs` from a tensor, populating only the requested fields.

        ``fields`` accepts the short names ``"device"``, ``"dtype"`` and ``"shape"``.
        Unrequested fields remain ``None``.
        """
        kwargs = {}
        if "device" in fields:
            kwargs["tgt_device"] = tensor.device
        if "dtype" in fields:
            kwargs["tgt_dtype"] = tensor.dtype
        if "shape" in fields:
            kwargs["tgt_shape"] = torch.Size(tensor.shape)
        return cls(batch_size=(), **kwargs)


# For __setitem__ and _update_at_ we don't pass a kwarg but use a global variable instead
_BREAK_ON_MEMMAP = True

3 changes: 3 additions & 0 deletions tensordict/tensorclass.pyi
@@ -1548,6 +1548,9 @@ class TensorClass:
    @overload
    def to(self, *, batch_size: torch.Size) -> Self: ...
    def to(self, *args, **kwargs) -> Self: ...
    def attrs(
        self,
        *,
        fields: Sequence[str] = ("device", "dtype", "shape"),
        num_threads: int | None = None,
    ) -> Self: ...
    def is_floating_point(self) -> bool: ...
    def double(self) -> Self: ...
    def float(self) -> Self: ...