@@ -16178,6 +16178,24 @@ def to(self, *args, **kwargs) -> Self:
         >>> assert td["x"].device.type == "cpu"
         >>> assert td["y"].device.type == "cpu" # Output also restored to original device
         """
+        # Per-leaf spec: a positional tensordict argument is interpreted as an
+        # attrs tensordict whose leaves are `TensorAttrs`; each source leaf
+        # is cast to its counterpart's device/dtype.
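+        # For example (illustrative; `other_td` stands for any reference
+        # tensordict whose leaves carry the target devices/dtypes):
+        #     td.to(other_td.attrs(fields=("device", "dtype")))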
+        if (
+            args
+            and not isinstance(args[0], Tensor)
+            and _is_tensor_collection(type(args[0]))
+        ):
+            attrs_td = args[0]
+            if len(args) > 1:
+                raise TypeError(
+                    "to(attrs_td) does not accept additional positional arguments."
+                )
+            return self._to_per_leaf(attrs_td, **kwargs)
+
         non_blocking = kwargs.pop("non_blocking", None)

         (
@@ -16277,6 +16292,156 @@ def get(name, val):
         self._sync_all()
         return result

+    def attrs(
+        self,
+        *,
+        fields: Sequence[str] = ("device", "dtype", "shape"),
+        num_threads: int | None = None,
+    ) -> Self:
+        """Return a deviceless tensordict whose leaves are :class:`~tensordict.TensorAttrs`.
+
+        Each tensor leaf of ``self`` is replaced by a :class:`~tensordict.TensorAttrs`
+        describing the requested tensor attributes. The result is intended to be passed to
+        :meth:`to` to drive per-leaf device/dtype casting when the source tensordict
+        is heterogeneous.
+
+        Keyword Args:
+            fields (sequence of str, optional): which attributes to record on each
+                :class:`~tensordict.TensorAttrs`. Accepts any subset of
+                ``("device", "dtype", "shape")``. Attributes not listed remain ``None``.
+                Defaults to ``("device", "dtype", "shape")``.
+            num_threads (int or None, optional): number of threads to use when
+                iterating over the leaves. Defaults to ``None`` (single-threaded).
+                Construction of :class:`TensorAttrs` is Python-bound, so threading
+                typically yields little benefit; exposed for symmetry with :meth:`to`.
+
+        Examples:
+            >>> import torch
+            >>> from tensordict import TensorDict
+            >>> td = TensorDict(
+            ...     {"a": torch.zeros(3, device="cpu"),
+            ...      "b": torch.zeros(3, dtype=torch.int32)},
+            ...     batch_size=[3],
+            ... )
+            >>> attrs = td.attrs(fields=("device", "dtype"))
+            >>> target = TensorDict(a=torch.zeros(3, device="cpu"), b=torch.zeros(3), batch_size=[3])
+            >>> out = target.to(attrs)  # casts ``b`` to int32 per-leaf
+            >>> out["b"].dtype
+            torch.int32
+        """
+        from tensordict.tensorclass import TensorAttrs
+
+        def _to_attrs(t):
+            return TensorAttrs.from_tensor(t, fields=fields)
+
+        return self._fast_apply(
+            _to_attrs,
+            batch_size=(),
+            device=None,
+            propagate_lock=False,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            num_threads=num_threads if num_threads is not None else 0,
+        )
+
+    def _to_per_leaf(
+        self,
+        attrs_td: TensorDictBase,
+        *,
+        non_blocking: bool | None = None,
+        non_blocking_pin: bool = False,
+        num_threads: int | None = None,
+        inplace: bool = False,
+    ) -> Self:
16355+ """Cast each leaf of ``self`` to the attributes recorded in ``attrs_td``.
16356+
16357+ Leaves absent from ``attrs_td`` (or whose attrs have both ``tgt_device=None``
16358+ and ``tgt_dtype=None``) are passed through unchanged.
16359+
16360+ When the caller does not pass ``non_blocking`` explicitly, per-leaf copies
16361+ are issued asynchronously. A single :meth:`_sync_all` is invoked at the end
16362+ only if at least one leaf went D2H (source on a CUDA/XPU/etc. device, target
16363+ on CPU). Async H2D copies do not need an explicit sync — subsequent kernels
16364+ on the destination device serialize on the same CUDA stream that queued the
16365+ copy, so the dependency is already honored. Only the D2H direction needs the
16366+ barrier because host reads are not coordinated by the device's stream scheduler.
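+
+        A minimal usage sketch; ``reference_td`` and ``td`` are illustrative
+        names, and the pass-through behavior follows the contract stated above:
+
+            >>> attrs_td = reference_td.attrs(fields=("device", "dtype"))
+            >>> out = td._to_per_leaf(attrs_td)  # leaves absent from attrs_td pass through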
16367+ """
16368+ from tensordict.tensorclass import TensorAttrs
16369+
+        if non_blocking_pin:
+            raise NotImplementedError(
+                "non_blocking_pin is not yet supported when an attrs tensordict is passed to `to()`."
+            )
+
+        if non_blocking is None:
+            sub_non_blocking = True
+            do_sync = True
+        else:
+            sub_non_blocking = non_blocking
+            do_sync = not non_blocking
+
+        def _is_attrs_leaf(cls):
+            return issubclass(cls, TensorAttrs) or _default_is_leaf(cls)
+
+        spec: dict = {}
+        for key, val in attrs_td.items(
+            include_nested=True, leaves_only=True, is_leaf=_is_attrs_leaf
+        ):
+            if isinstance(val, TensorAttrs):
+                spec[key] = val
+
+        # D2H (device-to-host) is the only transfer direction that needs an explicit
+        # sync after an async copy: host memory is outside the source device's stream
+        # scheduler, so reads after the copy call returns may observe stale data. H2D
+        # and cross-device D2D are fine: the destination's stream already serializes
+        # on the enqueued copy.
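+        # Illustrative directions: cuda -> cpu sets the flag below; cpu -> cuda
+        # and cuda:0 -> cuda:1 do not, per the reasoning above.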
+        needs_d2h_sync = False
+
+        def _cast(name, tensor):
+            attrs = spec.get(name)
+            if attrs is None:
+                return tensor
+            target_device = attrs.tgt_device
+            target_dtype = attrs.tgt_dtype
+            if target_device is None and target_dtype is None:
+                return tensor
+            if (
+                target_device is not None
+                and target_device.type == "cpu"
+                and tensor.device.type != "cpu"
+            ):
+                nonlocal needs_d2h_sync
+                needs_d2h_sync = True
+            return tensor.to(
+                device=target_device,
+                dtype=target_dtype,
+                non_blocking=sub_non_blocking,
+            )
+
+        result = self._fast_apply(
+            _cast,
+            named=True,
+            nested_keys=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            propagate_lock=True,
+            out=self if inplace else None,
+            checked=True,
+            device=None,
+            num_threads=num_threads if num_threads is not None else 0,
+        )
+
+        if needs_d2h_sync and do_sync:
+            self._sync_all()
+
+        return result
+
     def _to_consolidated(
         self, *, device, pin_memory, num_threads, non_blocking, inplace
     ):