diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index 0f510975d982..9a8ba258b0cd 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -304,7 +304,7 @@ supported: - _propagate_xla_data - put_ - _pdist_forward - - qr + - linalg_qr - random_ - random_.from - random_.to diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp index 0047ae123fab..86a861900405 100644 --- a/test/cpp/test_aten_xla_tensor_2.cpp +++ b/test/cpp/test_aten_xla_tensor_2.cpp @@ -1,6 +1,8 @@ #include +#include #include +#include #include #include @@ -560,19 +562,39 @@ TEST_F(AtenXlaTensorTest, TestQR) { static const int dims[] = {4, 7}; for (auto m : dims) { for (auto n : dims) { - torch::Tensor a = - torch::rand({m, n}, torch::TensorOptions(torch::kFloat)); - auto b = torch::qr(a); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - auto xla_b = torch::qr(xla_a); - AllClose(std::get<0>(b).abs(), std::get<0>(xla_b).abs(), /*rtol=*/1e-3, - /*atol=*/1e-4); - AllClose(std::get<1>(b).abs(), std::get<1>(xla_b).abs(), /*rtol=*/1e-3, - /*atol=*/1e-4); - }); + for (const auto mode : {"reduced", "complete"}) { + torch::Tensor a = + torch::rand({m, n}, torch::TensorOptions(torch::kFloat)); + auto b = torch::linalg_qr(a, mode); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = CopyToDevice(a, device); + auto xla_b = torch::linalg_qr(xla_a, mode); + int64_t k = mode == std::string("complete") ? m : std::min(m, n); + EXPECT_EQ(std::get<0>(xla_b).size(0), m); + EXPECT_EQ(std::get<0>(xla_b).size(1), k); + EXPECT_EQ(std::get<1>(xla_b).size(0), k); + EXPECT_EQ(std::get<1>(xla_b).size(1), n); + AllClose(std::get<0>(b).abs(), std::get<0>(xla_b).abs(), + /*rtol=*/1e-3, /*atol=*/1e-4); + AllClose(std::get<1>(b).abs(), std::get<1>(xla_b).abs(), + /*rtol=*/1e-3, /*atol=*/1e-4); + }); + } } } + + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = + CopyToDevice(torch::rand({4, 7}, torch::TensorOptions(torch::kFloat)), + device); + try { + torch::linalg_qr(xla_a, "raw"); + FAIL() << "Expected torch::linalg_qr to reject unsupported QR mode"; + } catch (const c10::Error& error) { + EXPECT_NE(std::string(error.what()).find("mode='raw'"), + std::string::npos); + } + }); } TEST_F(AtenXlaTensorTest, TestCholesky) { diff --git a/test/test_ops.py b/test/test_ops.py index 2e71948b76a1..02702c6ef1e8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -364,8 +364,6 @@ def get_allowed_ops_map( # AllowedOpInfoEntry('norm'), # AllowedOpInfoEntry('t'), # AllowedOpInfoEntry('logdet'), xla::lodget does not handle empty input - # AllowedOpInfoEntry('qr'), # Slice dim size 1 greater than dynamic slice dimension: 0 - # Worked locally (but failing on CI both CPU) # app.circleci.com/pipelines/github/pytorch/xla/9130/workflows/71c74f3d-1735-4328-81b5-784d6e6744da/jobs/17998 # AllowedOpInfoEntry('var_mean'), diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 1c855ca82397..f418b72236c7 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3118,9 +3118,18 @@ at::Tensor& XLANativeFunctions::put_(at::Tensor& self, const at::Tensor& index, return self; } -std::tuple XLANativeFunctions::qr( - const at::Tensor& self, bool some) { - TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); +std::tuple XLANativeFunctions::linalg_qr( + const at::Tensor& self, c10::string_view mode) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + bool some = true; + if (mode == "reduced") { + some = true; + } else if (mode == "complete") { + some = false; + } else { + TORCH_CHECK(false, "linalg_qr on XLA only supports modes 'reduced' and " + "'complete', but got mode='", mode, "'"); + } XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto results = tensor_methods::qr(xla_self, some); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp index 4db31feb6ba9..b63c7200e41d 100644 --- a/torch_xla/csrc/autocast_mode.cpp +++ b/torch_xla/csrc/autocast_mode.cpp @@ -144,11 +144,10 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32) // KERNEL_XLA(geqrf, fp32) // KERNEL_XLA(_lu_with_info, fp32) - KERNEL_XLA(qr, fp32) KERNEL_XLA(svd, fp32) KERNEL_XLA(triangular_solve, fp32) KERNEL_XLA(multilabel_margin_loss_forward, fp32) - // KERNEL_XLA(linalg_qr, fp32) + KERNEL_XLA(linalg_qr, fp32) // KERNEL_XLA(linalg_cholesky_ex, fp32) KERNEL_XLA(linalg_svd, fp32) // KERNEL_XLA(linalg_eig, fp32) diff --git a/torch_xla/csrc/ops/qr.cpp b/torch_xla/csrc/ops/qr.cpp index a067a77d41d6..0370d7d7c828 100644 --- a/torch_xla/csrc/ops/qr.cpp +++ b/torch_xla/csrc/ops/qr.cpp @@ -44,7 +44,7 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input, bool some) { QR::QR(const torch::lazy::Value& input, bool some) : XlaNode( - torch::lazy::OpKind(at::aten::qr), {input}, + torch::lazy::OpKind(at::aten::linalg_qr), {input}, [&]() { return NodeOutputShape(input, some); }, /*num_outputs=*/2, torch::lazy::MHash(some)), some_(some) {}