Skip to content

Commit 498ca38

Browse files
committed
Move code for CSR into extension
1 parent 9a3a0a3 commit 498ca38

2 files changed

Lines changed: 50 additions & 47 deletions

File tree

ext/FerriteSparseMatrixCSR.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module FerriteSparseMatrixCSR
22

33
using Ferrite, SparseArrays, SparseMatricesCSR
4-
import Ferrite: AbstractSparsityPattern, CSRAssembler
4+
import Ferrite: AbstractSparsityPattern, CSRAssembler, FastSparsityPattern, getnrows, getncols
55
import Base: @propagate_inbounds
66

77
# Could be generalized if https://github.com/JuliaSparse/SparseArrays.jl/pull/546 is merged
@@ -112,4 +112,51 @@ function _allocate_matrix(::Type{SparseMatrixCSR{1, Tv, Ti}}, sp::AbstractSparsi
112112
return S
113113
end
114114

115+
## ================= ##
116+
# FastSparsityPattern #
117+
## ================= ##
118+
119+
function _allocate_matrix(::Type{SparseMatrixCSR{1, Tv, Ti}}, sp::FastSparsityPattern{Ti}, sym::Bool) where {Tv, Ti}
120+
sym && throw(ArgumentError("FastSparsityPattern does not support symmetric matrices yet"))
121+
sp.is_colidx_sorted || sort_rows_threaded!(sp) # Require sorted rows
122+
nzval = zeros(Tv, length(sp.colidx))
123+
return SparseMatrixCSR{1}(getnrows(sp), getncols(sp), sp.rowptr, sp.colidx, nzval)
124+
end
125+
126+
function sort_rows!(sp)
127+
sort_rows!(sp, 1:getnrows(sp))
128+
sp.is_colidx_sorted = true
129+
return sp
130+
end
131+
132+
function sort_rows!(sp::FastSparsityPattern, rowrange::UnitRange)
133+
@inbounds for row in rowrange
134+
i1 = sp.rowptr[row]
135+
i2 = sp.rowptr[row + 1] - 1
136+
if i1 < i2
137+
sort!(view(sp.colidx, i1:i2))
138+
end
139+
end
140+
return sp
141+
end
142+
143+
function sort_rows_threaded!(
144+
sp::FastSparsityPattern, # Default ΔN ≥ 1000 and `n_tasks ≥ 1`
145+
ntasks = max(min(Threads.nthreads() * 100, getnrows(sp) ÷ 1000), 1)
146+
) # Otherwise, 100 per thread for load balancing
147+
nrows = getnrows(sp)
148+
ΔN = nrows ÷ ntasks
149+
Base.Experimental.@sync begin
150+
for taskid in 1:ntasks
151+
Threads.@spawn begin
152+
first_idx = 1 + ΔN * (taskid - 1)
153+
last_idx = min(first_idx + ΔN - 1, nrows)
154+
sort_rows!(sp, first_idx:last_idx)
155+
end
156+
end
157+
end
158+
sp.is_colidx_sorted = true
159+
return sp
160+
end
161+
115162
end

src/Dofs/sparsity_pattern.jl

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,9 @@ function _allocate_matrix(::Type{SparseMatrixCSC{Tv, Ti}}, sp::AbstractSparsityP
682682
return S
683683
end
684684

685-
## ================== ##
685+
## ================= ##
686686
# FastSparsityPattern #
687-
## ================== ##
687+
## ================= ##
688688

689689
# Full `AbstractSparsityPattern` interface not supported
690690
mutable struct FastSparsityPattern{Ti} <: AbstractSparsityPattern
@@ -829,47 +829,3 @@ function allocate_matrix(::Type{<:SparseMatrixCSC{Tv, Ti}}, sp::FastSparsityPatt
829829
nzval = zeros(Tv, nnz)
830830
return SparseMatrixCSC(nrows, ncols, colptr, rowidx, nzval)
831831
end
832-
833-
#= # TODO: Move to extension
834-
function sort_rows!(sp)
835-
sort_rows!(sp, 1:getnrows(sp))
836-
sp.is_sorted = true
837-
return sp
838-
end
839-
840-
function sort_rows!(sp::FastSparsityPattern, rowrange::UnitRange)
841-
@inbounds for row in rowrange
842-
i1 = sp.rowptr[row]
843-
i2 = sp.rowptr[row + 1] - 1
844-
if i1 < i2
845-
sort!(view(sp.colidx, i1:i2))
846-
end
847-
end
848-
return sp
849-
end
850-
851-
function sort_rows_threaded!(
852-
sp::FastSparsityPattern, # Default ΔN ≥ 1000 and `n_tasks ≥ 1`
853-
ntasks = max(min(Threads.nthreads() * 100, getnrows(sp) ÷ 1000), 1)
854-
) # Otherwise, 100 per thread for load balancing
855-
nrows = getnrows(sp)
856-
ΔN = nrows ÷ ntasks
857-
Base.Experimental.@sync begin
858-
for taskid in 1:ntasks
859-
Threads.@spawn begin
860-
first_idx = 1 + ΔN * (taskid - 1)
861-
last_idx = min(first_idx + ΔN - 1, nrows)
862-
sort_rows!(sp, first_idx:last_idx)
863-
end
864-
end
865-
end
866-
sp.is_sorted = true
867-
return sp
868-
end
869-
870-
function allocate_matrix(::Type{<:SparseMatrixCSR}, sp::FastSparsityPattern{Ti}) where {Ti}
871-
sp.is_sorted || sort_rows_threaded!(sp) # Require sorted rows
872-
nzval = zeros(Float64, length(sp.colidx))
873-
return SparseMatrixCSR{1}(ndofs(sp), ndofs(sp), sp.rowptr, sp.colidx, nzval)
874-
end
875-
=#

0 commit comments

Comments
 (0)