Why does ReplicateShardedData perform an uncached compile+device execute on each sharded output to return to host #9751

@jameszianxuTT

We've worked around the sharded output issue #9726 and can now produce sharded outputs using the Shardy path introduced in #9348.

However, we observe that gathering sharded outputs to host (e.g. sharded_tensor.to("cpu")) goes through the ReplicateShardedData function. It takes a sharded output, then compiles and executes an effective no-op computation whose output is tagged as replicated, so that the sharded tensor is gathered into a replicated tensor.

We have a few questions about this implementation:

  1. At the point of ReplicateShardedData, it seems like we have enough information from the OpSharding to reassemble the shards on host, rather than dispatching a device execute. Why aren't shards pulled to host and reassembled in host memory instead? Is this just for convenience (i.e. reusing existing unsharding infrastructure) or is there a more fundamental reason?
  2. Why isn't the compilation for the x+=0 graph in ReplicateShardedData cached in the computation cache?
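To make question 2 concrete, here is a minimal sketch of the kind of caching we would expect, assuming the replication graph depends only on the tensor's shape, dtype, and sharding. Every name in it (ReplicationCache, fake_compile, the sharding strings) is hypothetical and only stands in for the real compile step; it is not the actual computation cache.

```python
class ReplicationCache:
    """Hypothetical memo table for compiled replication computations."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._cache = {}
        self.compile_count = 0  # how many times we actually compiled

    def get(self, shape, dtype, sharding):
        # Assumption: these three values fully determine the no-op graph.
        key = (tuple(shape), dtype, sharding)
        if key not in self._cache:
            self.compile_count += 1
            self._cache[key] = self._compile_fn(*key)
        return self._cache[key]


def fake_compile(shape, dtype, sharding):
    # Stand-in for the expensive compile; returns an opaque handle.
    return f"replicate{shape}:{dtype}:{sharding}"


cache = ReplicationCache(fake_compile)
first = cache.get((8, 128), "f32", "tiled[2,1]")
second = cache.get((8, 128), "f32", "tiled[2,1]")  # same key: cache hit
third = cache.get((8, 256), "f32", "tiled[2,1]")   # new shape: compiles again
```

With a scheme like this, repeated transfers of same-shaped sharded outputs would compile once rather than on every call.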

A drawback of the existing implementation described in (1) is that it forces an expensive on-device all-gather that seems unnecessary, and it may exhaust device memory if the tensor being all-gathered (e.g. a large KV cache) cannot fit in device memory.
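For illustration, the host-side alternative asked about in (1) could look roughly like the sketch below, assuming a 2D tensor tiled across devices and assuming each shard's offset in the global tensor can be derived from the OpSharding. The function reassemble_on_host and the (offset, tile) shard layout are hypothetical, and plain Python lists stand in for tensor buffers.

```python
def reassemble_on_host(global_shape, shards):
    """Copy each shard into its slot of a host-side 2D buffer.

    `shards` is a hypothetical list of ((row0, col0), tile) pairs, where
    (row0, col0) is the tile's top-left offset in the global tensor.
    """
    rows, cols = global_shape
    out = [[None] * cols for _ in range(rows)]
    for (row0, col0), tile in shards:
        for i, tile_row in enumerate(tile):
            # Overwrite the matching span of the destination row.
            out[row0 + i][col0:col0 + len(tile_row)] = tile_row
    return out


# Example: a 4x6 tensor tiled 2x2, one 2x3 tile per device.
full = [[r * 6 + c for c in range(6)] for r in range(4)]
shards = [
    ((r0, c0), [row[c0:c0 + 3] for row in full[r0:r0 + 2]])
    for r0 in (0, 2)
    for c0 in (0, 3)
]
gathered = reassemble_on_host((4, 6), shards)
```

In this scheme no device ever holds the full tensor; only the host needs room for the assembled result.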
