@@ -9463,6 +9463,197 @@ def _recv(
9463 9463
9464 9464 return _tag
9465 9465
9466+ def dtensor_send(
9467+ self,
9468+ dst,
9469+ *,
9470+ dst_mesh=None,
9471+ dst_placements=None,
9472+ strategy: str = "auto",
9473+ transport: str = "auto",
9474+ group=None,
9475+ ) -> None:
9476+ """Send a TensorDict containing DTensors to a remote worker or set of workers.
9477+
9478+ Supports three strategies for handling different sharding layouts between
9479+ sender and receiver, and two transport backends (torch.distributed, UCXX).
9480+
9481+ Args:
9482+ dst: the destination. Can be an ``int`` rank (for torch.distributed),
9483+ or a :class:`~tensordict._ucxx.TensorDictPipe` (for UCXX).
9484+
9485+ Keyword Args:
9486+ dst_mesh: the destination :class:`~torch.distributed.device_mesh.DeviceMesh`.
9487+ Required for ``strategy="optimal"``.
9488+ dst_placements: per-key destination placements. Can be a single
9489+ ``tuple[Placement, ...]`` applied to all keys, or a
9490+ ``dict[str, tuple[Placement, ...]]``. Required for
9491+ ``strategy="optimal"``.
9492+ strategy (str): one of ``"materialize"`` (A), ``"redistribute"`` (B),
9493+ ``"optimal"`` (C), or ``"auto"``. ``"auto"`` currently resolves to
9494+ ``"materialize"``; once the ``"redistribute"`` and ``"optimal"``
9495+ strategies land, it will select ``"optimal"`` whenever mesh info is
9496+ available. Defaults to ``"auto"``.
9497+ transport (str): one of ``"torch_distributed"``, ``"ucxx"``, or ``"auto"``.
9498+ ``"auto"`` selects ``"ucxx"`` when *dst* is a :class:`~tensordict._ucxx.TensorDictPipe`,
9499+ else ``"torch_distributed"``. Defaults to ``"auto"``.
9500+ group: ``torch.distributed`` process group. Only used with
9501+ ``transport="torch_distributed"``.
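+
+ Examples:
+ A minimal two-rank sketch of the default (``"materialize"``) path;
+ assumes ``torch.distributed`` is initialized and ``dt`` is an
+ existing :class:`~torch.distributed.tensor.DTensor`:
+
+ >>> td = TensorDict({"weights": dt}, batch_size=[])
+ >>> td.dtensor_send(1)  # on rank 0
+ >>> out = TensorDict({}, batch_size=[])
+ >>> out.dtensor_recv(0)  # on rank 1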
9502+ """
9503+ from tensordict._dtensor import _get_transport_backend
9504+
9505+ backend = _get_transport_backend(transport, dst, group=group)
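+ # The backend returned by ``_get_transport_backend`` is assumed to expose
+ # four primitives that the strategy helpers below rely on: send_object /
+ # recv_object for picklable metadata, and send_tensor / recv_tensor for
+ # contiguous (receiver: pre-allocated) tensors.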
9506+
9507+ resolved = strategy
9508+ if resolved == "auto":
9509+ resolved = "materialize"  # strategies B/C are still stubs, so "auto" always materializes
9510+
9511+ if resolved == "materialize":
9512+ self._dtensor_send_materialize(dst, backend=backend)
9513+ elif resolved == "redistribute":
9514+ self._dtensor_send_redistribute(dst, backend=backend)
9515+ elif resolved == "optimal":
9516+ self._dtensor_send_optimal(
9517+ dst,
9518+ backend=backend,
9519+ dst_mesh=dst_mesh,
9520+ dst_placements=dst_placements,
9521+ )
9522+ else:
9523+ raise ValueError(
9524+ f"Unknown dtensor strategy {strategy!r}. "
9525+ "Expected 'materialize', 'redistribute', 'optimal', or 'auto'."
9526+ )
9527+
9528+ def dtensor_recv(
9529+ self,
9530+ src,
9531+ *,
9532+ src_mesh=None,
9533+ src_placements=None,
9534+ strategy: str = "auto",
9535+ transport: str = "auto",
9536+ group=None,
9537+ ) -> None:
9538+ """Receive a TensorDict containing DTensors from a remote worker.
9539+
9540+ This is the counterpart to :meth:`dtensor_send`.
9541+
9542+ Args:
9543+ src: the source. Can be an ``int`` rank (for torch.distributed),
9544+ or a :class:`~tensordict._ucxx.TensorDictPipe` (for UCXX).
9545+
9546+ Keyword Args:
9547+ src_mesh: the source :class:`~torch.distributed.device_mesh.DeviceMesh`.
9548+ Required for ``strategy="optimal"``.
9549+ src_placements: per-key source placements. Can be a single
9550+ ``tuple[Placement, ...]`` applied to all keys, or a
9551+ ``dict[str, tuple[Placement, ...]]``. Required for
9552+ ``strategy="optimal"``.
9553+ strategy (str): must match the *strategy* used by the sender
9554+ (note that ``"auto"`` currently resolves to ``"materialize"``). Defaults to ``"auto"``.
9555+ transport (str): one of ``"torch_distributed"``, ``"ucxx"``, or
9556+ ``"auto"``. Defaults to ``"auto"``.
9557+ group: ``torch.distributed`` process group. Only used with
9558+ ``transport="torch_distributed"``.
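+
+ Examples:
+ Receiving side of the sketch in :meth:`dtensor_send` (rank 1, with
+ rank 0 sending); with the default strategy the received values are
+ plain, unsharded tensors:
+
+ >>> out = TensorDict({}, batch_size=[])
+ >>> out.dtensor_recv(0)
+ >>> out["weights"]  # a plain full tensor on this rank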
9559+ """
9560+ from tensordict._dtensor import _get_transport_backend
9561+
9562+ backend = _get_transport_backend(transport, src, group=group)
9563+
9564+ resolved = strategy
9565+ if resolved == "auto":
9566+ resolved = "materialize"  # must mirror the sender's resolution of "auto"
9567+
9568+ if resolved == "materialize":
9569+ self._dtensor_recv_materialize(src, backend=backend)
9570+ elif resolved == "redistribute":
9571+ self._dtensor_recv_redistribute(src, backend=backend)
9572+ elif resolved == "optimal":
9573+ self._dtensor_recv_optimal(
9574+ src,
9575+ backend=backend,
9576+ src_mesh=src_mesh,
9577+ src_placements=src_placements,
9578+ )
9579+ else:
9580+ raise ValueError(
9581+ f"Unknown dtensor strategy {strategy!r}. "
9582+ "Expected 'materialize', 'redistribute', 'optimal', or 'auto'."
9583+ )
9584+
9585+ # -- Strategy A: materialize-and-reshard ----------------------------
9586+
9587+ def _dtensor_send_materialize(self, dst, *, backend) -> None:
9588+ """Send by materializing DTensors to full tensors first."""
9589+ metadata = {}
9590+ tensors = []
9591+ for key in self.sorted_keys:
9592+ value = self._get_str(key, NO_DEFAULT)
9593+ if _is_tensor_collection(type(value)):
9594+ raise NotImplementedError(
9595+ "Nested TensorDicts in dtensor_send are not yet supported."
9596+ )
9597+ if hasattr(value, "full_tensor"):  # duck-typed DTensor check
9598+ placements_str = [str(p) for p in value.placements]
9599+ metadata[key] = {
9600+ "is_dtensor": True,
9601+ "shape": list(value.shape),
9602+ "dtype": str(value.dtype),
9603+ "placements": placements_str,
9604+ }
9605+ value = value.full_tensor()  # collective: gathers the full tensor across the mesh
9606+ else:
9607+ metadata[key] = {
9608+ "is_dtensor": False,
9609+ "shape": list(value.shape),
9610+ "dtype": str(value.dtype),
9611+ }
9612+ tensors.append((key, value))
9613+
9614+ dst_int = dst if isinstance(dst, int) else 0  # non-int destinations (e.g. UCXX pipes) get a placeholder rank
9615+ backend.send_object(metadata, dst_int)
9616+ for key, tensor in tensors:
9617+ backend.send_tensor(tensor.contiguous(), dst_int)
9618+
9619+ def _dtensor_recv_materialize(self, src, *, backend) -> None:
9620+ """Receive full tensors and wrap them back as DTensors if needed."""
9621+ src_int = src if isinstance(src, int) else 0  # non-int sources (e.g. UCXX pipes) get a placeholder rank
9622+ metadata = backend.recv_object(src_int)
9623+
9624+ for key, meta in metadata.items():
9625+ shape = torch.Size(meta["shape"])
9626+ dtype = getattr(torch, meta["dtype"].replace("torch.", ""))  # e.g. "torch.float32" -> torch.float32
9627+ buf = torch.empty(shape, dtype=dtype)  # allocated on the default device (CPU unless overridden)
9628+ backend.recv_tensor(buf, src_int)
9629+ self._set_str(key, buf, inplace=False, validated=True)
9630+
9631+ # -- Strategy B / C stubs (to be implemented in follow-up PRs) ------
9632+
9633+ def _dtensor_send_redistribute(self, dst, *, backend) -> None:
9634+ raise NotImplementedError(
9635+ "Strategy 'redistribute' is not yet implemented. "
9636+ "Use strategy='materialize' for now."
9637+ )
9638+
9639+ def _dtensor_recv_redistribute(self, src, *, backend) -> None:
9640+ raise NotImplementedError(
9641+ "Strategy 'redistribute' is not yet implemented. "
9642+ "Use strategy='materialize' for now."
9643+ )
9644+
9645+ def _dtensor_send_optimal(self, dst, *, backend, dst_mesh, dst_placements) -> None:
9646+ raise NotImplementedError(
9647+ "Strategy 'optimal' is not yet implemented. "
9648+ "Use strategy='materialize' for now."
9649+ )
9650+
9651+ def _dtensor_recv_optimal(self, src, *, backend, src_mesh, src_placements) -> None:
9652+ raise NotImplementedError(
9653+ "Strategy 'optimal' is not yet implemented. "
9654+ "Use strategy='materialize' for now."
9655+ )
9656+
9466 9657 def init_remote(
9467 9658 self,
9468 9659 dst: int | None = None,