Handle empty inputs in XLA _linalg_svd#3252
Conversation
62e0bdd to
cb354eb
Compare
ce17370 to
d1cc628
Compare
|
It would be nice to go over this one again once pytorch/pytorch#69827 is merged. |
In PyTorch core the plan is to remove `torch.svd`, it is replaced with `torch.linalg.svd`. In ATen there are two different operations: 1. `aten::svd` (old) 2. `aten::linalg_svd` (new) This PR adds XLA lowering for `linalg_svd`. Resolves pytorch#2755 Unblocks pytorch/pytorch#57772
…ut that xla transpose doesn't like
Resolve conflicts by keeping current XLA SVD lowering and carrying forward empty-input handling for _linalg_svd.
There was a problem hiding this comment.
Pull request overview
This PR adds a special-case path for torch.linalg.svd lowering on empty XLA tensors, aligning XLA behavior with the newer ATen linalg_svd operator.
Changes:
- Handles
self.numel() == 0inside_linalg_svd. - Constructs empty-result shapes for
U,S, andVhbased onfull_matricesandcompute_uv.
Comments suppressed due to low confidence (3)
torch_xla/csrc/aten_xla_type.cpp:4185
- The singular values tensor is created with
self.options(), so complex inputs produce a complexS.torch.linalg.svdreturns real singular values for complex inputs, so this empty-input path should use the corresponding real dtype forswhile keepingu/vhin the input dtype.
auto s = at::zeros(singular_values_sizes, self.options());
torch_xla/csrc/aten_xla_type.cpp:4195
- When
full_matricesis true andn == 0, this returns a non-empty zero matrix forUwith shape(..., m, m). The SVD contract expects the returned singular vectors to be orthonormal/unitary; a zero square matrix is not, and it will differ from backends that return an identity basis for this empty-dimension case. Populate the full square factor with an identity basis instead of zeros.
auto u = at::zeros(u_sizes, self.options());
torch_xla/csrc/aten_xla_type.cpp:4200
- When
full_matricesis true andm == 0, this returns a non-empty zero matrix forVhwith shape(..., n, n). The full singular-vector factor should be unitary/orthonormal fortorch.linalg.svd; use an identity basis for the square factor instead of zeros in this empty-dimension case.
auto vh = at::zeros(vh_sizes, self.options());
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
After resolving conflicts against current |
|
Added focused C++ coverage in ab18ba0 for empty |
In PyTorch/XLA
torch.linalg.svdis implemented via_linalg_svd.The lowering already exists on
master; after resolving merge conflicts the only remaining issue is the zero-numel case.This PR handles empty inputs in
_linalg_svdand adds a small C++ test for the output shapes.Related: