Skip to content

Commit 798bdbf

Browse files
committed
Update
[ghstack-poisoned]
2 parents 3893199 + 00f3904 commit 798bdbf

File tree

2 files changed

+404
-3
lines changed

2 files changed

+404
-3
lines changed

tensordict/_dtensor.py

Lines changed: 147 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,19 @@ def _placement_is_partial(p) -> bool:
5656
# ---------------------------------------------------------------------------
5757

5858

59-
@dataclass(frozen=True)
59+
@dataclass(frozen=True, slots=True)
6060
class _ShardSpec:
6161
"""What a single rank holds, expressed as slices into the global tensor."""
6262

6363
rank: int
6464
slices: tuple[slice, ...]
6565

66+
def __repr__(self) -> str:
67+
slices_str = ", ".join(f"{s.start}:{s.stop}" for s in self.slices)
68+
return f"ShardSpec(rank={self.rank}, slices=[{slices_str}])"
6669

67-
@dataclass(frozen=True)
70+
71+
@dataclass(frozen=True, slots=True)
6872
class _ChunkTransfer:
6973
"""One point-to-point transfer instruction."""
7074

@@ -74,8 +78,27 @@ class _ChunkTransfer:
7478
dst_slices: tuple[slice, ...]
7579
global_slices: tuple[slice, ...]
7680

81+
@property
82+
def numel(self) -> int:
83+
"""Number of elements this transfer moves."""
84+
n = 1
85+
for s in self.global_slices:
86+
n *= s.stop - s.start
87+
return n
88+
89+
def nbytes(self, itemsize: int = 1) -> int:
90+
"""Number of bytes (numel * itemsize) this transfer moves."""
91+
return self.numel * itemsize
92+
93+
def __repr__(self) -> str:
94+
gl = ", ".join(f"{s.start}:{s.stop}" for s in self.global_slices)
95+
return (
96+
f"ChunkTransfer(src={self.src_rank}->dst={self.dst_rank}, "
97+
f"global=[{gl}], numel={self.numel})"
98+
)
99+
77100

78-
@dataclass
101+
@dataclass(slots=True)
79102
class _TransferPlan:
80103
"""Complete plan for transferring one tensor between two sharding specs."""
81104

@@ -88,6 +111,82 @@ def sends_for_rank(self, rank: int) -> list[_ChunkTransfer]:
88111
def recvs_for_rank(self, rank: int) -> list[_ChunkTransfer]:
89112
return [t for t in self.transfers if t.dst_rank == rank]
90113

114+
def __repr__(self) -> str:
115+
return (
116+
f"TransferPlan(shape={tuple(self.global_shape)}, "
117+
f"transfers={len(self.transfers)})"
118+
)
119+
120+
121+
@dataclass(frozen=True)
122+
class ShardingDescriptor:
123+
"""Describes how a logical tensor is distributed across a device mesh.
124+
125+
This is the bridge between framework-specific metadata (Megatron's
126+
partition_dim, vLLM's column/row parallel, FSDP's DTensor placements)
127+
and tensordict's framework-agnostic transfer plan computation.
128+
"""
129+
130+
mesh_shape: tuple[int, ...]
131+
placements: tuple
132+
logical_shape: torch.Size
133+
rank_map: dict[tuple[int, ...], int] | None = None
134+
135+
@classmethod
136+
def from_dtensor(cls, dtensor) -> ShardingDescriptor:
137+
"""Construct from a ``torch.distributed.tensor.DTensor``."""
138+
mesh = dtensor.device_mesh
139+
return cls(
140+
mesh_shape=tuple(mesh.mesh.shape),
141+
placements=tuple(dtensor.placements),
142+
logical_shape=dtensor.shape,
143+
rank_map=_mesh_to_rank_map(mesh),
144+
)
145+
146+
@classmethod
147+
def from_device_mesh(
148+
cls,
149+
mesh,
150+
placements: Sequence,
151+
logical_shape: torch.Size,
152+
) -> ShardingDescriptor:
153+
"""Construct from a DeviceMesh + placements + shape."""
154+
return cls(
155+
mesh_shape=tuple(mesh.mesh.shape),
156+
placements=tuple(placements),
157+
logical_shape=logical_shape,
158+
rank_map=_mesh_to_rank_map(mesh),
159+
)
160+
161+
@classmethod
162+
def replicated(cls, shape: torch.Size, world_size: int) -> ShardingDescriptor:
163+
"""All ranks hold a full copy."""
164+
from torch.distributed.tensor.placement_types import Replicate
165+
166+
return cls(
167+
mesh_shape=(world_size,),
168+
placements=(Replicate(),),
169+
logical_shape=shape,
170+
)
171+
172+
@classmethod
173+
def sharded(
174+
cls,
175+
shape: torch.Size,
176+
dim: int,
177+
world_size: int,
178+
rank_map: dict | None = None,
179+
) -> ShardingDescriptor:
180+
"""Simple 1D shard on a single dimension."""
181+
from torch.distributed.tensor.placement_types import Shard
182+
183+
return cls(
184+
mesh_shape=(world_size,),
185+
placements=(Shard(dim),),
186+
logical_shape=shape,
187+
rank_map=rank_map,
188+
)
189+
91190

92191
# ---------------------------------------------------------------------------
93192
# Slice arithmetic
@@ -292,6 +391,51 @@ def _compute_transfer_plan(
292391
return plan
293392

294393

394+
def execute_transfer_plan(
    plan: _TransferPlan,
    src_tensor: Tensor | None,
    dst_buffer: Tensor | None,
    rank: int,
    backend: _TransportBackend,
) -> None:
    """Execute a single-tensor transfer plan on this rank.

    This rank participates as sender, receiver, or both, depending on
    whether it appears in the plan's transfers.

    Args:
        plan: the precomputed TransferPlan.
        src_tensor: this rank's local shard (if rank is a source).
            Can be ``None`` if this rank is destination-only.
        dst_buffer: pre-allocated buffer for received data (if rank is
            a destination). Can be ``None`` if this rank is source-only.
        rank: this rank's global rank ID.
        backend: transport backend to use.
    """
    outgoing = plan.sends_for_rank(rank)
    incoming = plan.recvs_for_rank(rank)

    # Stage each incoming chunk into a scratch buffer sized from its
    # global slices; the copies into dst_buffer happen after the sends.
    # NOTE(review): recv_tensor is invoked before any send on this rank;
    # if the backend's recv blocks, mutually-sending rank pairs could
    # deadlock — confirm the backend's point-to-point semantics.
    staged: list[tuple[Tensor, _ChunkTransfer]] = []
    if dst_buffer is not None:
        for transfer_in in incoming:
            extent = tuple(sl.stop - sl.start for sl in transfer_in.global_slices)
            scratch = torch.empty(
                extent, dtype=dst_buffer.dtype, device=dst_buffer.device
            )
            backend.recv_tensor(scratch, transfer_in.src_rank)
            staged.append((scratch, transfer_in))

    if src_tensor is not None:
        for transfer_out in outgoing:
            # contiguous() so the backend gets a dense chunk to ship.
            payload = src_tensor[transfer_out.src_slices].contiguous()
            backend.send_tensor(payload, transfer_out.dst_rank)

    # Scatter the staged chunks into their destination slots.
    for scratch, transfer_in in staged:
        dst_buffer[transfer_in.dst_slices].copy_(scratch)
437+
438+
295439
# ---------------------------------------------------------------------------
296440
# Transport abstraction
297441
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)