Skip to content

Commit 7b28f2d

Browse files
authored
Bugfix and tests for cusolver/base (#2712)
1 parent cdb1e18 commit 7b28f2d

2 files changed

Lines changed: 65 additions & 23 deletions

File tree

lib/cusolver/base.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,73 +12,73 @@ version() = VersionNumber(cusolverGetProperty(CUDA.MAJOR_VERSION),
1212

1313
function Base.convert(::Type{cusolverEigType_t}, typ::Int)
1414
if typ == 1
15-
CUSOLVER_EIG_TYPE_1
15+
return CUSOLVER_EIG_TYPE_1
1616
elseif typ == 2
17-
CUSOLVER_EIG_TYPE_2
17+
return CUSOLVER_EIG_TYPE_2
1818
elseif typ == 3
19-
CUSOLVER_EIG_TYPE_3
19+
return CUSOLVER_EIG_TYPE_3
2020
else
2121
throw(ArgumentError("Unknown eigenvalue solver type $typ."))
2222
end
2323
end
2424

2525
function Base.convert(::Type{cusolverEigMode_t}, jobz::Char)
2626
if jobz == 'N'
27-
CUSOLVER_EIG_MODE_NOVECTOR
27+
return CUSOLVER_EIG_MODE_NOVECTOR
2828
elseif jobz == 'V'
29-
CUSOLVER_EIG_MODE_VECTOR
29+
return CUSOLVER_EIG_MODE_VECTOR
3030
else
3131
throw(ArgumentError("Unknown eigenvalue solver mode $jobz."))
3232
end
3333
end
3434

3535
function Base.convert(::Type{cusolverEigRange_t}, range::Char)
3636
if range == 'A'
37-
CUSOLVER_EIG_RANGE_ALL
37+
return CUSOLVER_EIG_RANGE_ALL
3838
elseif range == 'V'
39-
CUSOLVER_EIG_RANGE_V
39+
return CUSOLVER_EIG_RANGE_V
4040
elseif range == 'I'
41-
CUSOLVER_EIG_RANGE_I
41+
return CUSOLVER_EIG_RANGE_I
4242
else
4343
throw(ArgumentError("Unknown eigenvalue solver range $range."))
4444
end
4545
end
4646

4747
function Base.convert(::Type{cusolverStorevMode_t}, storev::Char)
4848
if storev == 'C'
49-
CUBLAS_STOREV_COLUMNWISE
49+
return CUBLAS_STOREV_COLUMNWISE
5050
elseif storev == 'R'
51-
CUBLAS_STOREV_ROWWISE
51+
return CUBLAS_STOREV_ROWWISE
5252
else
5353
throw(ArgumentError("Unknown storage mode $storev."))
5454
end
5555
end
5656

5757
function Base.convert(::Type{cusolverDirectMode_t}, direct::Char)
5858
if direct == 'F'
59-
CUBLAS_DIRECT_FORWARD
59+
return CUBLAS_DIRECT_FORWARD
6060
elseif direct == 'B'
61-
CUBLAS_DIRECT_BACKWARD
61+
return CUBLAS_DIRECT_BACKWARD
6262
else
6363
throw(ArgumentError("Unknown direction mode $direct."))
6464
end
6565
end
6666

6767
function Base.convert(::Type{cusolverIRSRefinement_t}, irs::String)
6868
if irs == "NOT_SET"
69-
CUSOLVER_IRS_REFINE_NOT_SET
69+
return CUSOLVER_IRS_REFINE_NOT_SET
7070
elseif irs == "NONE"
71-
CUSOLVER_IRS_REFINE_NONE
71+
return CUSOLVER_IRS_REFINE_NONE
7272
elseif irs == "CLASSICAL"
73-
CUSOLVER_IRS_REFINE_CLASSICAL
74-
elseif "CLASSICAL_GMRES"
75-
CUSOLVER_IRS_REFINE_CLASSICAL_GMRES
76-
elseif "GMRES"
77-
CUSOLVER_IRS_REFINE_GMRES
78-
elseif "GMRES_GMRES"
79-
CUSOLVER_IRS_REFINE_GMRES_GMRES
80-
elseif "GMRES_NOPCOND"
81-
CUSOLVER_IRS_REFINE_GMRES_NOPCOND
73+
return CUSOLVER_IRS_REFINE_CLASSICAL
74+
elseif irs == "CLASSICAL_GMRES"
75+
return CUSOLVER_IRS_REFINE_CLASSICAL_GMRES
76+
elseif irs == "GMRES"
77+
return CUSOLVER_IRS_REFINE_GMRES
78+
elseif irs == "GMRES_GMRES"
79+
return CUSOLVER_IRS_REFINE_GMRES_GMRES
80+
elseif irs == "GMRES_NOPCOND"
81+
return CUSOLVER_IRS_REFINE_GMRES_NOPCOND
8282
else
8383
throw(ArgumentError("Unknown iterative refinement solver $irs."))
8484
end

test/libraries/cusolver/base.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using CUDA.CUSOLVER
2+
3+
@testset "CUSOLVER helpers and types" begin
4+
@test convert(CUSOLVER.cusolverEigType_t, 1) == CUSOLVER.CUSOLVER_EIG_TYPE_1
5+
@test convert(CUSOLVER.cusolverEigType_t, 2) == CUSOLVER.CUSOLVER_EIG_TYPE_2
6+
@test convert(CUSOLVER.cusolverEigType_t, 3) == CUSOLVER.CUSOLVER_EIG_TYPE_3
7+
@test_throws ArgumentError("Unknown eigenvalue solver type 4.") convert(CUSOLVER.cusolverEigType_t, 4)
8+
9+
@test convert(CUSOLVER.cusolverEigMode_t, 'N') == CUSOLVER.CUSOLVER_EIG_MODE_NOVECTOR
10+
@test convert(CUSOLVER.cusolverEigMode_t, 'V') == CUSOLVER.CUSOLVER_EIG_MODE_VECTOR
11+
@test_throws ArgumentError("Unknown eigenvalue solver mode A.") convert(CUSOLVER.cusolverEigMode_t, 'A')
12+
13+
@test convert(CUSOLVER.cusolverEigRange_t, 'A') == CUSOLVER.CUSOLVER_EIG_RANGE_ALL
14+
@test convert(CUSOLVER.cusolverEigRange_t, 'V') == CUSOLVER.CUSOLVER_EIG_RANGE_V
15+
@test convert(CUSOLVER.cusolverEigRange_t, 'I') == CUSOLVER.CUSOLVER_EIG_RANGE_I
16+
@test_throws ArgumentError("Unknown eigenvalue solver range B.") convert(CUSOLVER.cusolverEigRange_t, 'B')
17+
18+
@test convert(CUSOLVER.cusolverStorevMode_t, 'C') == CUSOLVER.CUBLAS_STOREV_COLUMNWISE
19+
@test convert(CUSOLVER.cusolverStorevMode_t, 'R') == CUSOLVER.CUBLAS_STOREV_ROWWISE
20+
@test_throws ArgumentError("Unknown storage mode A.") convert(CUSOLVER.cusolverStorevMode_t, 'A')
21+
22+
@test convert(CUSOLVER.cusolverDirectMode_t, 'F') == CUSOLVER.CUBLAS_DIRECT_FORWARD
23+
@test convert(CUSOLVER.cusolverDirectMode_t, 'B') == CUSOLVER.CUBLAS_DIRECT_BACKWARD
24+
@test_throws ArgumentError("Unknown direction mode A.") convert(CUSOLVER.cusolverDirectMode_t, 'A')
25+
26+
@test convert(CUSOLVER.cusolverIRSRefinement_t, "NOT_SET") == CUSOLVER.CUSOLVER_IRS_REFINE_NOT_SET
27+
@test convert(CUSOLVER.cusolverIRSRefinement_t, "NONE") == CUSOLVER.CUSOLVER_IRS_REFINE_NONE
28+
@test convert(CUSOLVER.cusolverIRSRefinement_t, "CLASSICAL") == CUSOLVER.CUSOLVER_IRS_REFINE_CLASSICAL
29+
@test convert(CUSOLVER.cusolverIRSRefinement_t, "CLASSICAL_GMRES") == CUSOLVER.CUSOLVER_IRS_REFINE_CLASSICAL_GMRES
30+
@test convert(CUSOLVER.cusolverIRSRefinement_t, "GMRES") == CUSOLVER.CUSOLVER_IRS_REFINE_GMRES
31+
@test convert(CUSOLVER.cusolverIRSRefinement_t, "GMRES_GMRES") == CUSOLVER.CUSOLVER_IRS_REFINE_GMRES_GMRES
32+
@test convert(CUSOLVER.cusolverIRSRefinement_t, "GMRES_NOPCOND") == CUSOLVER.CUSOLVER_IRS_REFINE_GMRES_NOPCOND
33+
@test_throws ArgumentError("Unknown iterative refinement solver A.") convert(CUSOLVER.cusolverIRSRefinement_t, "A")
34+
35+
@test convert(CUSOLVER.cusolverPrecType_t, "R_16F") == CUSOLVER.CUSOLVER_R_16F
36+
@test convert(CUSOLVER.cusolverPrecType_t, "R_16BF") == CUSOLVER.CUSOLVER_R_16BF
37+
@test convert(CUSOLVER.cusolverPrecType_t, "R_TF32") == CUSOLVER.CUSOLVER_R_TF32
38+
@test convert(CUSOLVER.cusolverPrecType_t, "C_16F") == CUSOLVER.CUSOLVER_C_16F
39+
@test convert(CUSOLVER.cusolverPrecType_t, "C_16BF") == CUSOLVER.CUSOLVER_C_16BF
40+
@test convert(CUSOLVER.cusolverPrecType_t, "C_TF32") == CUSOLVER.CUSOLVER_C_TF32
41+
@test_throws ArgumentError("cusolverPrecType_t equivalent for input type A does not exist!") convert(CUSOLVER.cusolverPrecType_t, "A")
42+
end

0 commit comments

Comments
 (0)