
Commit ccdad8b

Update
[ghstack-poisoned]

1 parent e82e34e · 1 file changed

tensordict/base.py: 191 additions & 0 deletions
@@ -9463,6 +9463,197 @@ def _recv(
            return _tag

    def dtensor_send(
        self,
        dst,
        *,
        dst_mesh=None,
        dst_placements=None,
        strategy: str = "auto",
        transport: str = "auto",
        group=None,
    ) -> None:
        """Send a TensorDict containing DTensors to a remote worker or set of workers.

        Supports three strategies for handling different sharding layouts between
        sender and receiver, and two transport backends (torch.distributed, UCXX).

        Args:
            dst: the destination. Can be an ``int`` rank (for torch.distributed)
                or a :class:`~tensordict._ucxx.TensorDictPipe` (for UCXX).

        Keyword Args:
            dst_mesh: the destination :class:`~torch.distributed.device_mesh.DeviceMesh`.
                Required for ``strategy="optimal"``.
            dst_placements: per-key destination placements. Can be a single
                ``tuple[Placement, ...]`` applied to all keys, or a
                ``dict[str, tuple[Placement, ...]]``. Required for
                ``strategy="optimal"``.
            strategy (str): one of ``"materialize"`` (A), ``"redistribute"`` (B),
                ``"optimal"`` (C), or ``"auto"``. ``"auto"`` selects
                ``"optimal"`` when mesh info is available and falls back to
                ``"materialize"`` otherwise. Defaults to ``"auto"``.
            transport (str): one of ``"torch_distributed"``, ``"ucxx"``, or
                ``"auto"``. ``"auto"`` picks based on the type of *dst*.
                Defaults to ``"auto"``.
            group: ``torch.distributed`` process group. Only used with
                ``transport="torch_distributed"``.
        """
        from tensordict._dtensor import _get_transport_backend

        backend = _get_transport_backend(transport, dst, group=group)

        resolved = strategy
        if resolved == "auto":
            # "redistribute" and "optimal" are still stubs, so "auto"
            # currently always falls back to "materialize".
            resolved = "materialize"

        if resolved == "materialize":
            self._dtensor_send_materialize(dst, backend=backend)
        elif resolved == "redistribute":
            self._dtensor_send_redistribute(dst, backend=backend)
        elif resolved == "optimal":
            self._dtensor_send_optimal(
                dst,
                backend=backend,
                dst_mesh=dst_mesh,
                dst_placements=dst_placements,
            )
        else:
            raise ValueError(
                f"Unknown dtensor strategy {strategy!r}. "
                "Expected 'materialize', 'redistribute', 'optimal', or 'auto'."
            )

    def dtensor_recv(
        self,
        src,
        *,
        src_mesh=None,
        src_placements=None,
        strategy: str = "auto",
        transport: str = "auto",
        group=None,
    ) -> None:
        """Receive a TensorDict containing DTensors from a remote worker.

        This is the counterpart to :meth:`dtensor_send`.

        Args:
            src: the source. Can be an ``int`` rank (for torch.distributed)
                or a :class:`~tensordict._ucxx.TensorDictPipe` (for UCXX).

        Keyword Args:
            src_mesh: the source :class:`~torch.distributed.device_mesh.DeviceMesh`.
                Required for ``strategy="optimal"``.
            src_placements: per-key source placements. Can be a single
                ``tuple[Placement, ...]`` applied to all keys, or a
                ``dict[str, tuple[Placement, ...]]``. Required for
                ``strategy="optimal"``.
            strategy (str): must match the *strategy* used by the sender.
                Defaults to ``"auto"``.
            transport (str): one of ``"torch_distributed"``, ``"ucxx"``, or
                ``"auto"``. Defaults to ``"auto"``.
            group: ``torch.distributed`` process group. Only used with
                ``transport="torch_distributed"``.
        """
        from tensordict._dtensor import _get_transport_backend

        backend = _get_transport_backend(transport, src, group=group)

        resolved = strategy
        if resolved == "auto":
            # Must mirror the sender's resolution of "auto".
            resolved = "materialize"

        if resolved == "materialize":
            self._dtensor_recv_materialize(src, backend=backend)
        elif resolved == "redistribute":
            self._dtensor_recv_redistribute(src, backend=backend)
        elif resolved == "optimal":
            self._dtensor_recv_optimal(
                src,
                backend=backend,
                src_mesh=src_mesh,
                src_placements=src_placements,
            )
        else:
            raise ValueError(
                f"Unknown dtensor strategy {strategy!r}. "
                "Expected 'materialize', 'redistribute', 'optimal', or 'auto'."
            )

    # -- Strategy A: materialize-and-reshard ----------------------------

    def _dtensor_send_materialize(self, dst, *, backend) -> None:
        """Send by materializing DTensors to full tensors first."""
        metadata = {}
        tensors = []
        for key in self.sorted_keys:
            value = self._get_str(key, NO_DEFAULT)
            if _is_tensor_collection(type(value)):
                raise NotImplementedError(
                    "Nested TensorDicts in dtensor_send are not yet supported."
                )
            if hasattr(value, "full_tensor"):
                # DTensor: record the sharding layout, then gather the full
                # tensor (a collective across the sender's mesh).
                placements_str = [str(p) for p in value.placements]
                metadata[key] = {
                    "is_dtensor": True,
                    "shape": list(value.shape),
                    "dtype": str(value.dtype),
                    "placements": placements_str,
                }
                value = value.full_tensor()
            else:
                metadata[key] = {
                    "is_dtensor": False,
                    "shape": list(value.shape),
                    "dtype": str(value.dtype),
                }
            tensors.append((key, value))

        # Two-phase protocol: metadata first, then one tensor per key in
        # sorted-key order. Non-int destinations (e.g. UCXX pipes) carry
        # their own addressing, so the rank argument is a placeholder.
        dst_int = dst if isinstance(dst, int) else 0
        backend.send_object(metadata, dst_int)
        for key, tensor in tensors:
            backend.send_tensor(tensor.contiguous(), dst_int)

    def _dtensor_recv_materialize(self, src, *, backend) -> None:
        """Receive full tensors; sharding metadata is transmitted, but values
        are currently stored as plain tensors rather than re-wrapped as DTensors."""
        src_int = src if isinstance(src, int) else 0
        metadata = backend.recv_object(src_int)

        for key, meta in metadata.items():
            # Allocate a receive buffer from the transmitted shape/dtype.
            shape = torch.Size(meta["shape"])
            dtype = getattr(torch, meta["dtype"].replace("torch.", ""))
            buf = torch.empty(shape, dtype=dtype)
            backend.recv_tensor(buf, src_int)
            self._set_str(key, buf, inplace=False, validated=True)

    # -- Strategy B / C stubs (implemented in later PRs) ----------------

    def _dtensor_send_redistribute(self, dst, *, backend) -> None:
        raise NotImplementedError(
            "Strategy 'redistribute' is not yet implemented. "
            "Use strategy='materialize' for now."
        )

    def _dtensor_recv_redistribute(self, src, *, backend) -> None:
        raise NotImplementedError(
            "Strategy 'redistribute' is not yet implemented. "
            "Use strategy='materialize' for now."
        )

    def _dtensor_send_optimal(self, dst, *, backend, dst_mesh, dst_placements) -> None:
        raise NotImplementedError(
            "Strategy 'optimal' is not yet implemented. "
            "Use strategy='materialize' for now."
        )

    def _dtensor_recv_optimal(self, src, *, backend, src_mesh, src_placements) -> None:
        raise NotImplementedError(
            "Strategy 'optimal' is not yet implemented. "
            "Use strategy='materialize' for now."
        )

    def init_remote(
        self,
        dst: int | None = None,
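
For context, here is a minimal usage sketch of the new API (not part of this commit, and assuming a tensordict build that includes this patch): a two-process gloo job in which rank 0 sends a TensorDict to rank 1 with the "materialize" strategy. Plain tensors are used so the example stands alone; with actual DTensors, value.full_tensor() is a collective across the sender's mesh, so the mesh must not include the receiving rank.

# Hypothetical usage sketch. Launch with: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist
from tensordict import TensorDict

dist.init_process_group("gloo")
rank = dist.get_rank()

if rank == 0:
    # Plain tensors exercise the same metadata/tensor protocol; DTensor
    # values would additionally be gathered via full_tensor() on this side.
    td = TensorDict(
        {"bias": torch.zeros(4), "weights": torch.randn(8, 4)}, batch_size=[]
    )
    td.dtensor_send(1, strategy="materialize")
else:
    # The receiver starts empty: keys, shapes, and dtypes are reconstructed
    # from the metadata dict sent in phase one.
    td = TensorDict({}, batch_size=[])
    td.dtensor_recv(0, strategy="materialize")
    assert td["weights"].shape == torch.Size([8, 4])

dist.destroy_process_group()

The phase-one metadata for this call would look like {"bias": {"is_dtensor": False, "shape": [4], "dtype": "torch.float32"}, "weights": {"is_dtensor": False, "shape": [8, 4], "dtype": "torch.float32"}}, after which the tensors follow one by one in sorted-key order.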

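The diff leaves _get_transport_backend to tensordict._dtensor, which is not shown here. As a reading aid, the following is a hedged sketch (an assumption, not the module's actual code) of the four-method interface the materialize path relies on: send_object/recv_object for picklable metadata and send_tensor/recv_tensor for preallocated buffers, backed here by torch.distributed point-to-point ops.

# Hypothetical transport backend satisfying the calls made by
# _dtensor_send_materialize / _dtensor_recv_materialize.
import torch
import torch.distributed as dist


class TorchDistributedBackend:
    """Sketch of a backend for integer-rank destinations."""

    def __init__(self, group=None):
        self.group = group

    def send_object(self, obj, dst: int) -> None:
        # send_object_list pickles arbitrary Python objects under the hood.
        dist.send_object_list([obj], dst=dst, group=self.group)

    def recv_object(self, src: int):
        placeholder = [None]
        dist.recv_object_list(placeholder, src=src, group=self.group)
        return placeholder[0]

    def send_tensor(self, tensor: torch.Tensor, dst: int) -> None:
        dist.send(tensor, dst=dst, group=self.group)

    def recv_tensor(self, buf: torch.Tensor, src: int) -> None:
        # Fills ``buf`` in place with the incoming payload.
        dist.recv(buf, src=src, group=self.group)

A UCXX-backed equivalent would implement the same four methods on top of a TensorDictPipe, which is why the strategy implementations above only ever talk to the backend object.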