From d222c9e1d95debdc660f3f801a92ff873e58b9a0 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 13 Dec 2021 14:43:51 +0000 Subject: [PATCH 1/8] Add lowering for torch.linalg.svd In PyTorch core the plan is to remove `torch.svd`, it is replaced with `torch.linalg.svd`. In ATen there are two different operations: 1. `aten::svd` (old) 2. `aten::linalg_svd` (new) This PR adds XLA lowering for `linalg_svd`. Resolves https://github.com/pytorch/xla/issues/2755 Unblocks https://github.com/pytorch/pytorch/pull/57772 --- test/cpp/test_aten_xla_tensor.cpp | 26 +++++++++++ torch_xla/csrc/aten_xla_type.cpp | 10 +++++ torch_xla/csrc/ops/svd.cpp | 72 +++++++++++++++++++++++++++++++ torch_xla/csrc/ops/svd.h | 16 +++++++ 4 files changed, 124 insertions(+) diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 1d155609cce4..1eba19636ebc 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -846,6 +846,32 @@ TEST_F(AtenXlaTensorTest, TestSVD) { } } +TEST_F(AtenXlaTensorTest, TestLinalgSVD) { + static const int dims[] = {4, 7}; + for (auto m : dims) { + for (auto n : dims) { + torch::Tensor a = + torch::rand({m, n}, torch::TensorOptions(torch::kFloat)); + auto b = torch::linalg::svd(a, /*full_matrices=*/false); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = CopyToDevice(a, device); + auto xla_b = torch::linalg::svd(xla_a, /*full_matrices=*/false); + // The U and V matrices might have different sign for column vectors, so + // cannot be compared if not by absolute value. + AllClose(std::get<0>(b).abs(), std::get<0>(xla_b).abs(), /*rtol=*/1e-3, + /*atol=*/1e-4); + torch::Tensor diag = std::get<1>(b); + torch::Tensor xla_diag = std::get<1>(xla_b); + ASSERT_EQ(diag.sizes(), xla_diag.sizes()); + AllClose(diag, xla_diag, /*rtol=*/1e-3, + /*atol=*/1e-4); + AllClose(std::get<2>(b).abs(), std::get<2>(xla_b).abs(), /*rtol=*/1e-3, + /*atol=*/1e-4); + }); + } + } +} + TEST_F(AtenXlaTensorTest, TestQR) { static const int dims[] = {4, 7}; for (auto m : dims) { diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 76b4afee2685..bc8aa36e5e4c 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3257,6 +3257,16 @@ std::tuple XLANativeFunctions::svd( bridge::AtenFromXlaTensor(std::get<2>(results))); } +std::tuple XLANativeFunctions::linalg_svd( + const at::Tensor& self, bool full_matrices) { + XLA_FN_COUNTER("xla::"); + auto results = + XLATensor::linalg_svd(bridge::GetXlaTensor(self), full_matrices); + return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), + bridge::AtenFromXlaTensor(std::get<1>(results)), + bridge::AtenFromXlaTensor(std::get<2>(results))); +} + std::tuple XLANativeFunctions::symeig( const at::Tensor& self, bool eigenvectors, bool upper) { XLA_FN_COUNTER("xla::"); diff --git a/torch_xla/csrc/ops/svd.cpp b/torch_xla/csrc/ops/svd.cpp index f081a59c666b..adcda716710b 100644 --- a/torch_xla/csrc/ops/svd.cpp +++ b/torch_xla/csrc/ops/svd.cpp @@ -63,6 +63,55 @@ xla::Shape NodeOutputShape(const XlaValue& input, bool some, bool compute_uv) { return xla::ShapeUtil::MakeTupleShape({ushape, dshape, vshape}); } +std::vector LowerLinalgSVD(xla::XlaOp input, bool full_matrices) { + xla::SVDResult svd_result = + xla::SVD(input, /*max_iter=*/100, /*epsilon=*/1e-6, + XlaHelpers::mat_mul_precision()); + const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); + xla::XlaOp u = svd_result.u; + xla::XlaOp v = svd_result.v; + if (!full_matrices) { + xla::int64 m_dim = input_shape.dimensions(input_shape.rank() - 2); + xla::int64 n_dim = input_shape.dimensions(input_shape.rank() - 1); + std::vector base_indices(input_shape.rank(), 0); + + auto u_sizes = xla::util::ToVector(input_shape.dimensions()); + u_sizes[input_shape.rank() - 1] = std::min(m_dim, n_dim); + u = BuildSlice(u, base_indices, u_sizes); + + auto v_sizes = xla::util::ToVector(input_shape.dimensions()); + v_sizes[input_shape.rank() - 2] = n_dim; + v_sizes[input_shape.rank() - 1] = std::min(m_dim, n_dim); + v = BuildSlice(v, base_indices, v_sizes); + } + auto permute_dims = XlaHelpers::MakeTransposePermutation( + /*dim0=*/input_shape.rank() - 2, /*dim1=*/input_shape.rank() - 1, + /*rank=*/input_shape.rank()); + xla::XlaOp vh = xla::Transpose(v, permute_dims); + return {u, svd_result.d, vh}; +} + +xla::Shape NodeOutputShape(const Value& input, bool full_matrices) { + const xla::Shape& input_shape = input.shape(); + XLA_CHECK_GE(input_shape.rank(), 2) << input_shape; + // The input tensor is ..., M x N + xla::int64 m_dim = input_shape.dimensions(input_shape.rank() - 2); + xla::int64 n_dim = input_shape.dimensions(input_shape.rank() - 1); + // U is M x M or M x min(M, N) + xla::Shape ushape(input_shape); + ushape.set_dimensions(input_shape.rank() - 1, + full_matrices ? m_dim : std::min(m_dim, n_dim)); + // D is min(M, N). + xla::Shape dshape = xla::ShapeUtil::MakeShape(input_shape.element_type(), + {std::min(m_dim, n_dim)}); + // Vh is N x N or min(M, N) x N + xla::Shape vshape(input_shape); + vshape.set_dimensions(input_shape.rank() - 1, n_dim); + vshape.set_dimensions(input_shape.rank() - 2, + full_matrices ? n_dim : std::min(m_dim, n_dim)); + return xla::ShapeUtil::MakeTupleShape({ushape, dshape, vshape}); +} + } // namespace SVD::SVD(const XlaValue& input, bool some, bool compute_uv) @@ -88,4 +137,27 @@ std::string SVD::ToString() const { return ss.str(); } +LinalgSVD::LinalgSVD(const Value& input, bool full_matrices) + : Node( + ir::OpKind(at::aten::linalg_svd), {input}, + [&]() { return NodeOutputShape(input, full_matrices); }, + /*num_outputs=*/3, xla::util::MHash(full_matrices)), + full_matrices_(full_matrices) {} + +NodePtr LinalgSVD::Clone(OpList operands) const { + return MakeNode(operands.at(0), full_matrices_); +} + +XlaOpVector LinalgSVD::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + return ReturnOps(LowerLinalgSVD(input, full_matrices_), loctx); +} + +std::string LinalgSVD::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", some=" << some_ + << ", compute_uv=" << compute_uv_; + return ss.str(); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/svd.h b/torch_xla/csrc/ops/svd.h index a3bebaeabd52..0f105c2e98f6 100644 --- a/torch_xla/csrc/ops/svd.h +++ b/torch_xla/csrc/ops/svd.h @@ -23,4 +23,20 @@ class SVD : public XlaNode { bool compute_uv_; }; +class LinalgSVD : public Node { + public: + LinalgSVD(const Value& input, bool full_matrices); + + std::string ToString() const override; + + NodePtr Clone(OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + bool full_matrices() const { return full_matrices_; } + + private: + bool full_matrices_; +}; + } // namespace torch_xla From 49735857bd4ca2eedce3e80429bd2a0eeb65c23d Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 15 Dec 2021 10:26:34 +0000 Subject: [PATCH 2/8] clang-format-7 --- torch_xla/csrc/ops/svd.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/ops/svd.cpp b/torch_xla/csrc/ops/svd.cpp index adcda716710b..a384cb6b768e 100644 --- a/torch_xla/csrc/ops/svd.cpp +++ b/torch_xla/csrc/ops/svd.cpp @@ -138,10 +138,9 @@ std::string SVD::ToString() const { } LinalgSVD::LinalgSVD(const Value& input, bool full_matrices) - : Node( - ir::OpKind(at::aten::linalg_svd), {input}, - [&]() { return NodeOutputShape(input, full_matrices); }, - /*num_outputs=*/3, xla::util::MHash(full_matrices)), + : Node(ir::OpKind(at::aten::linalg_svd), {input}, + [&]() { return NodeOutputShape(input, full_matrices); }, + /*num_outputs=*/3, xla::util::MHash(full_matrices)), full_matrices_(full_matrices) {} NodePtr LinalgSVD::Clone(OpList operands) const { From 4a52e744939fc8073795dfc68209758198d48e50 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 15 Dec 2021 11:38:30 +0000 Subject: [PATCH 3/8] Add linalg_svd xla_native_functions.yaml --- torch_xla/csrc/aten_xla_type.cpp | 11 ++--- torch_xla/csrc/ops/svd.cpp | 71 -------------------------------- torch_xla/csrc/ops/svd.h | 16 ------- xla_native_functions.yaml | 1 + 4 files changed, 7 insertions(+), 92 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index bc8aa36e5e4c..8919ffc2870f 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3260,11 +3260,12 @@ std::tuple XLANativeFunctions::svd( std::tuple XLANativeFunctions::linalg_svd( const at::Tensor& self, bool full_matrices) { XLA_FN_COUNTER("xla::"); - auto results = - XLATensor::linalg_svd(bridge::GetXlaTensor(self), full_matrices); - return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), - bridge::AtenFromXlaTensor(std::get<1>(results)), - bridge::AtenFromXlaTensor(std::get<2>(results))); + auto results = XLATensor::svd(bridge::GetXlaTensor(self), + /*some=*/!full_matrices, /*compute_uv=*/true); + return std::make_tuple( + bridge::AtenFromXlaTensor(std::get<0>(results)), + bridge::AtenFromXlaTensor(std::get<1>(results)), + bridge::AtenFromXlaTensor(std::get<2>(results)).conj().transpose(-2, -1)); } std::tuple XLANativeFunctions::symeig( diff --git a/torch_xla/csrc/ops/svd.cpp b/torch_xla/csrc/ops/svd.cpp index a384cb6b768e..f081a59c666b 100644 --- a/torch_xla/csrc/ops/svd.cpp +++ b/torch_xla/csrc/ops/svd.cpp @@ -63,55 +63,6 @@ xla::Shape NodeOutputShape(const XlaValue& input, bool some, bool compute_uv) { return xla::ShapeUtil::MakeTupleShape({ushape, dshape, vshape}); } -std::vector LowerLinalgSVD(xla::XlaOp input, bool full_matrices) { - xla::SVDResult svd_result = - xla::SVD(input, /*max_iter=*/100, /*epsilon=*/1e-6, - XlaHelpers::mat_mul_precision()); - const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); - xla::XlaOp u = svd_result.u; - xla::XlaOp v = svd_result.v; - if (!full_matrices) { - xla::int64 m_dim = input_shape.dimensions(input_shape.rank() - 2); - xla::int64 n_dim = input_shape.dimensions(input_shape.rank() - 1); - std::vector base_indices(input_shape.rank(), 0); - - auto u_sizes = xla::util::ToVector(input_shape.dimensions()); - u_sizes[input_shape.rank() - 1] = std::min(m_dim, n_dim); - u = BuildSlice(u, base_indices, u_sizes); - - auto v_sizes = xla::util::ToVector(input_shape.dimensions()); - v_sizes[input_shape.rank() - 2] = n_dim; - v_sizes[input_shape.rank() - 1] = std::min(m_dim, n_dim); - v = BuildSlice(v, base_indices, v_sizes); - } - auto permute_dims = XlaHelpers::MakeTransposePermutation( - /*dim0=*/input_shape.rank() - 2, /*dim1=*/input_shape.rank() - 1, - /*rank=*/input_shape.rank()); - xla::XlaOp vh = xla::Transpose(v, permute_dims); - return {u, svd_result.d, vh}; -} - -xla::Shape NodeOutputShape(const Value& input, bool full_matrices) { - const xla::Shape& input_shape = input.shape(); - XLA_CHECK_GE(input_shape.rank(), 2) << input_shape; - // The input tensor is ..., M x N - xla::int64 m_dim = input_shape.dimensions(input_shape.rank() - 2); - xla::int64 n_dim = input_shape.dimensions(input_shape.rank() - 1); - // U is M x M or M x min(M, N) - xla::Shape ushape(input_shape); - ushape.set_dimensions(input_shape.rank() - 1, - full_matrices ? m_dim : std::min(m_dim, n_dim)); - // D is min(M, N). - xla::Shape dshape = xla::ShapeUtil::MakeShape(input_shape.element_type(), - {std::min(m_dim, n_dim)}); - // Vh is N x N or min(M, N) x N - xla::Shape vshape(input_shape); - vshape.set_dimensions(input_shape.rank() - 1, n_dim); - vshape.set_dimensions(input_shape.rank() - 2, - full_matrices ? n_dim : std::min(m_dim, n_dim)); - return xla::ShapeUtil::MakeTupleShape({ushape, dshape, vshape}); -} - } // namespace SVD::SVD(const XlaValue& input, bool some, bool compute_uv) @@ -137,26 +88,4 @@ std::string SVD::ToString() const { return ss.str(); } -LinalgSVD::LinalgSVD(const Value& input, bool full_matrices) - : Node(ir::OpKind(at::aten::linalg_svd), {input}, - [&]() { return NodeOutputShape(input, full_matrices); }, - /*num_outputs=*/3, xla::util::MHash(full_matrices)), - full_matrices_(full_matrices) {} - -NodePtr LinalgSVD::Clone(OpList operands) const { - return MakeNode(operands.at(0), full_matrices_); -} - -XlaOpVector LinalgSVD::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - return ReturnOps(LowerLinalgSVD(input, full_matrices_), loctx); -} - -std::string LinalgSVD::ToString() const { - std::stringstream ss; - ss << Node::ToString() << ", some=" << some_ - << ", compute_uv=" << compute_uv_; - return ss.str(); -} - } // namespace torch_xla diff --git a/torch_xla/csrc/ops/svd.h b/torch_xla/csrc/ops/svd.h index 0f105c2e98f6..a3bebaeabd52 100644 --- a/torch_xla/csrc/ops/svd.h +++ b/torch_xla/csrc/ops/svd.h @@ -23,20 +23,4 @@ class SVD : public XlaNode { bool compute_uv_; }; -class LinalgSVD : public Node { - public: - LinalgSVD(const Value& input, bool full_matrices); - - std::string ToString() const override; - - NodePtr Clone(OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - bool full_matrices() const { return full_matrices_; } - - private: - bool full_matrices_; -}; - } // namespace torch_xla diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 5b0405b05b76..7965763eeec2 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -162,6 +162,7 @@ supported: - lerp.Scalar - lerp.Tensor - linspace + - linalg_svd - log - log1p - log2 From ffdc2dfb2aaf7596fb8fb4f4e204497e79f0e0d0 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 16 Dec 2021 13:49:03 +0000 Subject: [PATCH 4/8] Try skipping opinfo-based tests for linalg.svd; they generate 0x0 input that xla transpose doesn't like --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index ffa1ddbf74ff..0512d3da16e4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -193,8 +193,6 @@ def __new__(cls, name, variant_test_name=""): AllowedOpInfoEntry('linalg.solve'), AllowedOpInfoEntry('linalg.matrix_rank'), AllowedOpInfoEntry('einsum'), - AllowedOpInfoEntry('linalg.svd'), - AllowedOpInfoEntry('linalg.svdvals'), AllowedOpInfoEntry('polar'), AllowedOpInfoEntry('ravel'), AllowedOpInfoEntry('reshape'), @@ -339,6 +337,8 @@ def __new__(cls, name, variant_test_name=""): # AllowedOpInfoEntry('erfinv'), # AllowedOpInfoEntry('norm'), # AllowedOpInfoEntry('t'), + # AllowedOpInfoEntry('linalg.svd'), + # AllowedOpInfoEntry('linalg.svdvals'), # Failed on CUDA CI only (investigate) # app.circleci.com/pipelines/github/pytorch/xla/9088/workflows/2d59c649-db2b-4384-921e-5e43eba1b51a/jobs/17875 From 2c645b0a9d639e4ac0e0203fe82f59aabb5f8f15 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 12 May 2022 13:39:55 +0000 Subject: [PATCH 5/8] Try handling 0 numel separately --- torch_xla/csrc/aten_xla_type.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 8919ffc2870f..c054de87dfeb 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3260,6 +3260,21 @@ std::tuple XLANativeFunctions::svd( std::tuple XLANativeFunctions::linalg_svd( const at::Tensor& self, bool full_matrices) { XLA_FN_COUNTER("xla::"); + if (self.numel() == 0) { + auto sizes = self.sizes().vec(); + const auto m = sizes.cend()[-2]; + const auto n = sizes.cend()[-1]; + const auto k = std::min(m, n); + sizes.back() = full_matrices ? m : k; + auto U = at::zeros(sizes, self.options()); + sizes.end()[-2] = full_matrices ? n : k; + sizes.end()[-1] = n; + auto Vh = at::zeros(sizes, self.options()); + sizes.pop_back(); + sizes.end()[-1] = k; + auto S = at::zeros(sizes, self.options()); + return std::make_tuple(std::move(U), std::move(S), std::move(Vh)); + } auto results = XLATensor::svd(bridge::GetXlaTensor(self), /*some=*/!full_matrices, /*compute_uv=*/true); return std::make_tuple( From 3f0a82711a0592a85ee2f85fba23487a1f0ea815 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 19 May 2026 12:21:01 +0300 Subject: [PATCH 6/8] Validate SVD rank before empty-input handling --- torch_xla/csrc/aten_xla_type.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 56262b969946..66eae074c35a 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -4175,6 +4175,8 @@ std::tuple XLANativeFunctions::_linalg_svd( std::optional /* driver */) { // The optional driver string is only for CUDA with a cuSOLVER backend. TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + TORCH_CHECK(self.dim() >= 2, + "linalg.svd: The input tensor A must have at least 2 dimensions."); if (self.numel() == 0) { auto singular_values_sizes = self.sizes().vec(); const auto m = singular_values_sizes.cend()[-2]; From 19a316a7c3a69caa6f56d688946e8c09bc0053c9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 19 May 2026 12:25:56 +0300 Subject: [PATCH 7/8] Avoid duplicate SVD rank validation --- torch_xla/csrc/aten_xla_type.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 66eae074c35a..407236cf7f8d 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -4175,12 +4175,10 @@ std::tuple XLANativeFunctions::_linalg_svd( std::optional /* driver */) { // The optional driver string is only for CUDA with a cuSOLVER backend. TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - TORCH_CHECK(self.dim() >= 2, - "linalg.svd: The input tensor A must have at least 2 dimensions."); if (self.numel() == 0) { + const auto m = self.size(-2); + const auto n = self.size(-1); auto singular_values_sizes = self.sizes().vec(); - const auto m = singular_values_sizes.cend()[-2]; - const auto n = singular_values_sizes.cend()[-1]; const auto k = std::min(m, n); singular_values_sizes.pop_back(); singular_values_sizes.back() = k; From ab18ba0c2b01b24598a3f498751727db3c4346fc Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 19 May 2026 12:27:24 +0300 Subject: [PATCH 8/8] Test empty XLA linalg SVD outputs --- test/cpp/test_aten_xla_tensor_2.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp index 0047ae123fab..7c63449ee7e0 100644 --- a/test/cpp/test_aten_xla_tensor_2.cpp +++ b/test/cpp/test_aten_xla_tensor_2.cpp @@ -2,6 +2,7 @@ #include #include +#include #include @@ -429,6 +430,30 @@ TEST_F(AtenXlaTensorTest, TestLinalgSVD) { ExpectCounterChanged("xla::_linalg_svd", cpp_test::GetIgnoredCounters()); } +TEST_F(AtenXlaTensorTest, TestLinalgSVDEmpty) { + std::vector> empty_sizes = { + {0, 0}, {0, 3}, {3, 0}, {2, 0, 3}, {2, 3, 0}}; + for (const auto& sizes : empty_sizes) { + for (bool full_matrices : {false, true}) { + for (bool compute_uv : {false, true}) { + torch::Tensor a = + torch::empty(sizes, torch::TensorOptions(torch::kFloat)); + auto expected = torch::_linalg_svd(a, full_matrices, compute_uv); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = CopyToDevice(a, device); + auto actual = torch::_linalg_svd(xla_a, full_matrices, compute_uv); + ASSERT_EQ(std::get<0>(expected).sizes(), + std::get<0>(actual).sizes()); + ASSERT_EQ(std::get<1>(expected).sizes(), + std::get<1>(actual).sizes()); + ASSERT_EQ(std::get<2>(expected).sizes(), + std::get<2>(actual).sizes()); + }); + } + } + } +} + TEST_F(AtenXlaTensorTest, TestLinalgVectorNorm) { torch::Tensor a = torch::rand({4, 3}, torch::TensorOptions(torch::kFloat)); std::vector ords = {0.0, 1.5, std::numeric_limits::infinity(),