Skip to content

Commit 93871d8

Browse files
authored
More CUSOLVER dense tests (#2723)
1 parent a8cf98a commit 93871d8

1 file changed

Lines changed: 99 additions & 19 deletions

File tree

test/libraries/cusolver/dense.jl

Lines changed: 99 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,59 @@ k = 1
1111

1212
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
1313
@testset "gesv!" begin
14-
A = rand(elty, n, n)
15-
X = zeros(elty, n, p)
16-
B = rand(elty, n, p)
17-
dA = CuArray(A)
18-
dX = CuArray(X)
19-
dB = CuArray(B)
20-
CUSOLVER.gesv!(dX, dA, dB)
21-
tol = real(elty) |> eps |> sqrt
22-
dR = dB - dA * dX
23-
@test norm(dR) <= tol
14+
@testset "irs_precision = AUTO" begin
15+
A = rand(elty, n, n)
16+
X = zeros(elty, n, p)
17+
B = rand(elty, n, p)
18+
dA = CuArray(A)
19+
dX = CuArray(X)
20+
dB = CuArray(B)
21+
CUSOLVER.gesv!(dX, dA, dB)
22+
tol = real(elty) |> eps |> sqrt
23+
dR = dB - dA * dX
24+
@test norm(dR) <= tol
25+
end
26+
@testset "irs_precision = $elty" begin
27+
irs_precision = elty <: Real ? "R_" : "C_"
28+
irs_precision *= string(sizeof(real(elty)) * 8) * "F"
29+
A = rand(elty, n, n)
30+
X = zeros(elty, n, p)
31+
B = rand(elty, n, p)
32+
dA = CuArray(A)
33+
dX = CuArray(X)
34+
dB = CuArray(B)
35+
CUSOLVER.gesv!(dX, dA, dB; irs_precision=irs_precision)
36+
tol = real(elty) |> eps |> sqrt
37+
dR = dB - dA * dX
38+
@test norm(dR) <= tol
39+
end
2440
end
2541

2642
@testset "gels!" begin
27-
A = rand(elty, m, n)
28-
X = zeros(elty, n, p)
29-
B = A * rand(elty, n, p) # ensure that AX = B is consistent
30-
dA = CuArray(A)
31-
dX = CuArray(X)
32-
dB = CuArray(B)
33-
CUSOLVER.gels!(dX, dA, dB)
34-
tol = real(elty) |> eps |> sqrt
35-
dR = dB - dA * dX
43+
@testset "irs_precision = AUTO" begin
44+
A = rand(elty, m, n)
45+
X = zeros(elty, n, p)
46+
B = A * rand(elty, n, p) # ensure that AX = B is consistent
47+
dA = CuArray(A)
48+
dX = CuArray(X)
49+
dB = CuArray(B)
50+
CUSOLVER.gels!(dX, dA, dB)
51+
tol = real(elty) |> eps |> sqrt
52+
dR = dB - dA * dX
53+
end
54+
@testset "irs_precision = $elty" begin
55+
irs_precision = elty <: Real ? "R_" : "C_"
56+
irs_precision *= string(sizeof(real(elty)) * 8) * "F"
57+
A = rand(elty, m, n)
58+
X = zeros(elty, n, p)
59+
B = A * rand(elty, n, p) # ensure that AX = B is consistent
60+
dA = CuArray(A)
61+
dX = CuArray(X)
62+
dB = CuArray(B)
63+
CUSOLVER.gels!(dX, dA, dB; irs_precision=irs_precision)
64+
tol = real(elty) |> eps |> sqrt
65+
dR = dB - dA * dX
66+
end
3667
end
3768

3869
@testset "geqrf! -- orgqr!" begin
@@ -41,6 +72,14 @@ k = 1
4172
dA, τ = CUSOLVER.geqrf!(dA)
4273
CUSOLVER.orgqr!(dA, τ)
4374
@test dA' * dA I
75+
dB = CuArray(A)
76+
dB, τ_b = LAPACK.geqrf!(dB)
77+
LAPACK.orgqr!(dB, τ_b)
78+
@test dB dA
79+
@test τ τ_b
80+
dB, τ_b = LAPACK.geqrf!(dB, similar(τ_b))
81+
LAPACK.orgqr!(dB, τ_b)
82+
@test dB dA
4483
end
4584

4685
@testset "ormqr!" begin
@@ -160,6 +199,12 @@ k = 1
160199
h_ipiv = collect(d_ipiv)
161200
alu = LinearAlgebra.LU(h_A, convert(Vector{BlasInt},h_ipiv), zero(BlasInt))
162201
@test A Array(alu)
202+
d_B = CuArray(A)
203+
d_B,d_ipiv,info = LAPACK.getrf!(d_B)
204+
@test d_B d_A
205+
d_B = CuArray(A)
206+
d_B,d_ipiv,info = LAPACK.getrf!(d_B, similar(d_ipiv))
207+
@test d_B d_A
163208

164209
d_A,d_ipiv,info = CUSOLVER.getrf!(CUDA.zeros(elty,n,n))
165210
@test_throws LinearAlgebra.SingularException LinearAlgebra.checknonsingular(info)
@@ -172,6 +217,9 @@ k = 1
172217
B = rand(elty,n,n)
173218
d_B = CuArray(B)
174219
d_B = CUSOLVER.getrs!('N',d_A,d_ipiv,d_B)
220+
d_C = CuArray(B)
221+
d_C = LAPACK.getrs!('N',d_A,d_ipiv,d_C)
222+
@test d_C d_B
175223
h_B = collect(d_B)
176224
@test h_B A\B
177225
A = rand(elty,m,n)
@@ -188,13 +236,19 @@ k = 1
188236
A = rand(elty,n,n)
189237
A = A + A' #symmetric
190238
d_A = CuArray(A)
239+
d_B = CuArray(A)
240+
d_C = CuArray(A)
191241
d_A,d_ipiv,info = CUSOLVER.sytrf!('U',d_A)
192242
LinearAlgebra.checknonsingular(info)
193243
h_A = collect(d_A)
194244
h_ipiv = collect(d_ipiv)
195245
A, ipiv = LAPACK.sytrf!('U',A)
196246
@test ipiv == h_ipiv
197247
@test A h_A
248+
d_B, ipiv_b = LAPACK.sytrf!('U',d_B)
249+
@test d_B d_A
250+
d_C, ipiv_b = LAPACK.sytrf!('U',d_C, similar(ipiv_b))
251+
@test d_C d_A
198252
A = rand(elty,m,n)
199253
d_A = CuArray(A)
200254
@test_throws DimensionMismatch CUSOLVER.sytrf!('U',d_A)
@@ -206,6 +260,7 @@ k = 1
206260
@testset "gebrd!" begin
207261
A = rand(elty,m,n)
208262
d_A = CuArray(A)
263+
d_B = CuArray(A)
209264
d_A, d_D, d_E, d_TAUQ, d_TAUP = CUSOLVER.gebrd!(d_A)
210265
h_A = collect(d_A)
211266
h_D = collect(d_D)
@@ -218,6 +273,8 @@ k = 1
218273
@test e[min(m,n)-1] h_E[min(m,n)-1]
219274
@test q h_TAUQ
220275
@test p h_TAUP
276+
d_B, d_D, d_E, d_TAUQ, d_TAUP = LAPACK.gebrd!(d_B)
277+
@test d_B d_A
221278
end
222279

223280
@testset "gesvd!" begin
@@ -233,6 +290,9 @@ k = 1
233290
d_A = CuMatrix(A)
234291
U2, Σ2, Vt2 = CUSOLVER.gesvd!(jobu, jobvt, d_A)
235292
@test Σ Σ2
293+
d_A = CuMatrix(A)
294+
U2, Σ2, Vt2 = LAPACK.gesvd!(jobu, jobvt, d_A)
295+
@test Σ Σ2
236296
end
237297
end
238298
end
@@ -244,8 +304,20 @@ k = 1
244304
local d_W, d_V
245305
if( elty <: Complex )
246306
d_W, d_V = CUSOLVER.heevd!('V','U', d_A)
307+
d_W_b, d_V_b = LAPACK.syev!('V','U', CuArray(A))
308+
@test d_W d_W_b
309+
@test d_V d_V_b
310+
d_W_c, d_V_c = LAPACK.syevd!('V','U', CuArray(A))
311+
@test d_W d_W_c
312+
@test d_V d_V_c
247313
else
248314
d_W, d_V = CUSOLVER.syevd!('V','U', d_A)
315+
d_W_b, d_V_b = LAPACK.syev!('V','U', CuArray(A))
316+
@test d_W d_W_b
317+
@test d_V d_V_b
318+
d_W_c, d_V_c = LAPACK.syevd!('V','U', CuArray(A))
319+
@test d_W d_W_c
320+
@test d_V d_V_c
249321
end
250322
h_W = collect(d_W)
251323
h_V = collect(d_V)
@@ -291,8 +363,16 @@ k = 1
291363
local d_W, d_VA, d_VB
292364
if( elty <: Complex )
293365
d_W, d_VA, d_VB = CUSOLVER.hegvd!(1, 'V','U', d_A, d_B)
366+
d_W2, d_VA2, d_VB2 = LAPACK.sygvd!(1, 'V','U', CuArray(A), CuArray(B))
367+
@test d_W2 d_W
368+
@test d_VA2 d_VA
369+
@test d_VB2 d_VB
294370
else
295371
d_W, d_VA, d_VB = CUSOLVER.sygvd!(1, 'V','U', d_A, d_B)
372+
d_W2, d_VA2, d_VB2 = LAPACK.sygvd!(1, 'V','U', CuArray(A), CuArray(B))
373+
@test d_W2 d_W
374+
@test d_VA2 d_VA
375+
@test d_VB2 d_VB
296376
end
297377
h_W = collect(d_W)
298378
h_VA = collect(d_VA)

0 commit comments

Comments
 (0)