@@ -56,15 +56,19 @@ def _placement_is_partial(p) -> bool:
5656# ---------------------------------------------------------------------------
5757
5858
59- @dataclass (frozen = True )
59+ @dataclass (frozen = True , slots = True )
6060class _ShardSpec :
6161 """What a single rank holds, expressed as slices into the global tensor."""
6262
6363 rank : int
6464 slices : tuple [slice , ...]
6565
66+ def __repr__ (self ) -> str :
67+ slices_str = ", " .join (f"{ s .start } :{ s .stop } " for s in self .slices )
68+ return f"ShardSpec(rank={ self .rank } , slices=[{ slices_str } ])"
6669
67- @dataclass (frozen = True )
70+
71+ @dataclass (frozen = True , slots = True )
6872class _ChunkTransfer :
6973 """One point-to-point transfer instruction."""
7074
@@ -74,8 +78,27 @@ class _ChunkTransfer:
7478 dst_slices : tuple [slice , ...]
7579 global_slices : tuple [slice , ...]
7680
81+ @property
82+ def numel (self ) -> int :
83+ """Number of elements this transfer moves."""
84+ n = 1
85+ for s in self .global_slices :
86+ n *= s .stop - s .start
87+ return n
88+
89+ def nbytes (self , itemsize : int = 1 ) -> int :
90+ """Number of bytes (numel * itemsize) this transfer moves."""
91+ return self .numel * itemsize
92+
93+ def __repr__ (self ) -> str :
94+ gl = ", " .join (f"{ s .start } :{ s .stop } " for s in self .global_slices )
95+ return (
96+ f"ChunkTransfer(src={ self .src_rank } ->dst={ self .dst_rank } , "
97+ f"global=[{ gl } ], numel={ self .numel } )"
98+ )
99+
77100
78- @dataclass
101+ @dataclass ( slots = True )
79102class _TransferPlan :
80103 """Complete plan for transferring one tensor between two sharding specs."""
81104
@@ -88,6 +111,82 @@ def sends_for_rank(self, rank: int) -> list[_ChunkTransfer]:
88111 def recvs_for_rank (self , rank : int ) -> list [_ChunkTransfer ]:
89112 return [t for t in self .transfers if t .dst_rank == rank ]
90113
114+ def __repr__ (self ) -> str :
115+ return (
116+ f"TransferPlan(shape={ tuple (self .global_shape )} , "
117+ f"transfers={ len (self .transfers )} )"
118+ )
119+
120+
121+ @dataclass (frozen = True )
122+ class ShardingDescriptor :
123+ """Describes how a logical tensor is distributed across a device mesh.
124+
125+ This is the bridge between framework-specific metadata (Megatron's
126+ partition_dim, vLLM's column/row parallel, FSDP's DTensor placements)
127+ and tensordict's framework-agnostic transfer plan computation.
128+ """
129+
130+ mesh_shape : tuple [int , ...]
131+ placements : tuple
132+ logical_shape : torch .Size
133+ rank_map : dict [tuple [int , ...], int ] | None = None
134+
135+ @classmethod
136+ def from_dtensor (cls , dtensor ) -> ShardingDescriptor :
137+ """Construct from a ``torch.distributed.tensor.DTensor``."""
138+ mesh = dtensor .device_mesh
139+ return cls (
140+ mesh_shape = tuple (mesh .mesh .shape ),
141+ placements = tuple (dtensor .placements ),
142+ logical_shape = dtensor .shape ,
143+ rank_map = _mesh_to_rank_map (mesh ),
144+ )
145+
146+ @classmethod
147+ def from_device_mesh (
148+ cls ,
149+ mesh ,
150+ placements : Sequence ,
151+ logical_shape : torch .Size ,
152+ ) -> ShardingDescriptor :
153+ """Construct from a DeviceMesh + placements + shape."""
154+ return cls (
155+ mesh_shape = tuple (mesh .mesh .shape ),
156+ placements = tuple (placements ),
157+ logical_shape = logical_shape ,
158+ rank_map = _mesh_to_rank_map (mesh ),
159+ )
160+
161+ @classmethod
162+ def replicated (cls , shape : torch .Size , world_size : int ) -> ShardingDescriptor :
163+ """All ranks hold a full copy."""
164+ from torch .distributed .tensor .placement_types import Replicate
165+
166+ return cls (
167+ mesh_shape = (world_size ,),
168+ placements = (Replicate (),),
169+ logical_shape = shape ,
170+ )
171+
172+ @classmethod
173+ def sharded (
174+ cls ,
175+ shape : torch .Size ,
176+ dim : int ,
177+ world_size : int ,
178+ rank_map : dict | None = None ,
179+ ) -> ShardingDescriptor :
180+ """Simple 1D shard on a single dimension."""
181+ from torch .distributed .tensor .placement_types import Shard
182+
183+ return cls (
184+ mesh_shape = (world_size ,),
185+ placements = (Shard (dim ),),
186+ logical_shape = shape ,
187+ rank_map = rank_map ,
188+ )
189+
91190
92191# ---------------------------------------------------------------------------
93192# Slice arithmetic
@@ -292,6 +391,51 @@ def _compute_transfer_plan(
292391 return plan
293392
294393
def execute_transfer_plan(
    plan: _TransferPlan,
    src_tensor: Tensor | None,
    dst_buffer: Tensor | None,
    rank: int,
    backend: _TransportBackend,
) -> None:
    """Execute a single-tensor transfer plan on this rank.

    Depending on whether *rank* appears as a source and/or destination in
    the plan, this rank sends chunks of ``src_tensor``, receives chunks
    into scratch buffers, and finally scatters the received chunks into
    ``dst_buffer``.

    Args:
        plan: the precomputed TransferPlan.
        src_tensor: this rank's local shard (if rank is a source).
            Can be ``None`` if this rank is destination-only.
        dst_buffer: pre-allocated buffer for received data (if rank is
            a destination). Can be ``None`` if this rank is source-only.
        rank: this rank's global rank ID.
        backend: transport backend to use.
    """
    outgoing = plan.sends_for_rank(rank)
    incoming = plan.recvs_for_rank(rank)

    # NOTE(review): receives are issued before sends; if recv_tensor blocks,
    # a symmetric exchange could deadlock — presumably the backend posts
    # receives asynchronously. TODO confirm against the backend contract.
    staged: list[tuple[Tensor, _ChunkTransfer]] = []
    if dst_buffer is not None:
        for item in incoming:
            shape = tuple(sl.stop - sl.start for sl in item.global_slices)
            scratch = torch.empty(
                shape, dtype=dst_buffer.dtype, device=dst_buffer.device
            )
            backend.recv_tensor(scratch, item.src_rank)
            staged.append((scratch, item))

    if src_tensor is not None:
        for item in outgoing:
            # Contiguity is required by the transport layer's wire format.
            backend.send_tensor(src_tensor[item.src_slices].contiguous(), item.dst_rank)

    # Scatter received chunks into their destination slots last, once all
    # sends for this rank have been issued.
    for scratch, item in staged:
        dst_buffer[item.dst_slices].copy_(scratch)
438+
295439# ---------------------------------------------------------------------------
296440# Transport abstraction
297441# ---------------------------------------------------------------------------
0 commit comments