Skip to content

Commit a347761

Browse files
Add assembly of rectangular sparse matrices (#1279)
1 parent eb76ddc commit a347761

4 files changed

Lines changed: 141 additions & 45 deletions

File tree

ext/FerriteSparseMatrixCSR.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,39 @@ import Base: @propagate_inbounds
77
# Could be generalized if https://github.com/JuliaSparse/SparseArrays.jl/pull/546 is merged
88
function Ferrite.start_assemble(K::SparseMatrixCSR{<:Any, T}, f::Vector = T[]; fillzero::Bool = true, maxcelldofs_hint::Int = 0) where {T}
99
fillzero && (Ferrite.fillzero!(K); Ferrite.fillzero!(f))
10-
return CSRAssembler(K, f, zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint))
10+
return CSRAssembler(K, f, zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint))
1111
end
1212

13-
@propagate_inbounds function Ferrite._assemble_inner!(K::SparseMatrixCSR, Ke::AbstractMatrix, dofs::AbstractVector, sorteddofs::AbstractVector, permutation::AbstractVector, sym::Bool)
13+
@propagate_inbounds function Ferrite._assemble_inner!(
14+
K::SparseMatrixCSR, Ke::AbstractMatrix,
15+
rowdofs::AbstractVector, sortedrowdofs::AbstractVector, rowpermutation::AbstractVector,
16+
coldofs::AbstractVector, sortedcoldofs::AbstractVector, colpermutation::AbstractVector,
17+
sym::Bool
18+
)
1419
current_row = 1
15-
ld = length(dofs)
16-
return @inbounds for Krow in sorteddofs
20+
ld = length(coldofs)
21+
return @inbounds for Krow in sortedrowdofs
1722
maxlookups = sym ? current_row : ld
18-
Kerow = permutation[current_row]
23+
Kerow = rowpermutation[current_row]
1924
ci = 1 # col index pointer for the local matrix
2025
Ci = 1 # col index pointer for the global matrix
2126
nzr = nzrange(K, Krow)
2227
while Ci <= length(nzr) && ci <= maxlookups
2328
C = nzr[Ci]
2429
Kcol = K.colval[C]
25-
Kecol = permutation[ci]
30+
Kecol = colpermutation[ci]
2631
val = Ke[Kerow, Kecol]
27-
if Kcol == dofs[Kecol]
32+
if Kcol == coldofs[Kecol]
2833
# Match: add the value (if non-zero) and advance the pointers
2934
if !iszero(val)
3035
K.nzval[C] += val
3136
end
3237
ci += 1
3338
Ci += 1
34-
elseif Kcol < dofs[Kecol]
39+
elseif Kcol < coldofs[Kecol]
3540
# No match yet: advance the global matrix column pointer
3641
Ci += 1
37-
else # Kcol > dofs[Kecol]
42+
else # Kcol > coldofs[Kecol]
3843
# No match: no entry exists in the global matrix for this column. This is
3944
# allowed as long as the value which would have been inserted is zero.
4045
iszero(val) || Ferrite._missing_sparsity_pattern_error(Krow, Kcol)
@@ -44,8 +49,8 @@ end
4449
end
4550
# Make sure that remaining entries in this row of the local matrix are all zero
4651
for i in ci:maxlookups
47-
if !iszero(Ke[Kerow, permutation[i]])
48-
Ferrite._missing_sparsity_pattern_error(Krow, sorteddofs[i])
52+
if !iszero(Ke[Kerow, colpermutation[i]])
53+
Ferrite._missing_sparsity_pattern_error(Krow, sortedcoldofs[i])
4954
end
5055
end
5156
current_row += 1

src/assembler.jl

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ Assembler for sparse matrix with CSC storage type.
176176
struct CSCAssembler{Tv, Ti, MT <: AbstractSparseMatrixCSC{Tv, Ti}} <: AbstractCSCAssembler
177177
K::MT
178178
f::Vector{Tv}
179-
permutation::Vector{Int}
180-
sorteddofs::Vector{Int}
179+
rowpermutation::Vector{Int}
180+
colpermutation::Vector{Int}
181+
sortedrowdofs::Vector{Int}
182+
sortedcoldofs::Vector{Int}
181183
end
182184

183185
"""
@@ -186,8 +188,10 @@ Assembler for sparse matrix with CSR storage type.
186188
struct CSRAssembler{Tv, Ti, MT <: AbstractSparseMatrix{Tv, Ti}} <: AbstractCSRAssembler #AbstractSparseMatrixCSR does not exist
187189
K::MT
188190
f::Vector{Tv}
189-
permutation::Vector{Int}
190-
sorteddofs::Vector{Int}
191+
rowpermutation::Vector{Int}
192+
colpermutation::Vector{Int}
193+
sortedrowdofs::Vector{Int}
194+
sortedcoldofs::Vector{Int}
191195
end
192196

193197
"""
@@ -196,8 +200,10 @@ Assembler for symmetric sparse matrix with CSC storage type.
196200
struct SymmetricCSCAssembler{Tv, Ti, MT <: Symmetric{Tv, <:AbstractSparseMatrixCSC{Tv, Ti}}} <: AbstractCSCAssembler
197201
K::MT
198202
f::Vector{Tv}
199-
permutation::Vector{Int}
200-
sorteddofs::Vector{Int}
203+
rowpermutation::Vector{Int} # Symmetric assembly doesn't need separate row and
204+
colpermutation::Vector{Int} # col permutation and dofs, but simplifies code reuse
205+
sortedrowdofs::Vector{Int} # with non-symmetric cases. sortedrowdofs and
206+
sortedcoldofs::Vector{Int} # rowpermutation always aliased to sortedcoldofs and colpermutation.
201207
end
202208

203209
function Base.show(io::IO, ::MIME"text/plain", a::Union{CSCAssembler, CSRAssembler, SymmetricCSCAssembler})
@@ -239,11 +245,13 @@ start_assemble(K::Union{AbstractSparseMatrixCSC, Symmetric{<:Any, <:AbstractSpar
239245

240246
function start_assemble(K::AbstractSparseMatrixCSC{T}, f::Vector = T[]; fillzero::Bool = true, maxcelldofs_hint::Int = 0) where {T}
241247
fillzero && (fillzero!(K); fillzero!(f))
242-
return CSCAssembler(K, f, zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint))
248+
return CSCAssembler(K, f, zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint))
243249
end
244250
function start_assemble(K::Symmetric{T, <:SparseMatrixCSC}, f::Vector = T[]; fillzero::Bool = true, maxcelldofs_hint::Int = 0) where {T}
245251
fillzero && (fillzero!(K); fillzero!(f))
246-
return SymmetricCSCAssembler(K, f, zeros(Int, maxcelldofs_hint), zeros(Int, maxcelldofs_hint))
252+
permutation = zeros(Int, maxcelldofs_hint)
253+
sorteddofs = zeros(Int, maxcelldofs_hint)
254+
return SymmetricCSCAssembler(K, f, permutation, permutation, sorteddofs, sorteddofs)
247255
end
248256

249257
function finish_assemble(a::Union{CSCAssembler, CSRAssembler, SymmetricCSCAssembler})
@@ -254,19 +262,29 @@ end
254262
assemble!(A::AbstractAssembler, dofs::AbstractVector{Int}, Ke::AbstractMatrix)
255263
assemble!(A::AbstractAssembler, dofs::AbstractVector{Int}, Ke::AbstractMatrix, fe::AbstractVector)
256264
257-
Assemble the element stiffness matrix `Ke` (and optional force vector `fe`) into the global
265+
Assemble the square element stiffness matrix `Ke` (and optional force vector `fe`) into the global
258266
stiffness (and force) in `A`, given the element degrees of freedom `dofs`.
259267
260-
This is equivalent to `K[dofs, dofs] += Ke` and `f[dofs] += fe`, where `K` is the global
261-
stiffness matrix and `f` the global force/residual vector, but more efficient.
268+
This is equivalent to `K[dofs, dofs] += Ke` and `f[dofs] += fe`, where `K` is the global stiffness matrix and `f` the global force/residual vector, but more efficient.
269+
270+
assemble!(A::AbstractAssembler, rowdofs::AbstractVector{Int}, coldofs::AbstractVector{Int}, Ke::AbstractMatrix)
271+
assemble!(A::AbstractAssembler, rowdofs::AbstractVector{Int}, coldofs::AbstractVector{Int}, Ke::AbstractMatrix, fe::AbstractVector)
272+
273+
Assemble the element stiffness matrix `Ke` (and optional force vector `fe`) into the global
274+
stiffness (and force) in `A`, given the element row degrees of freedom, `rowdofs`, and element column degrees of freedom, `coldofs`.
275+
This is equivalent to `K[rowdofs, coldofs] += Ke` and `f[rowdofs] += fe`, but more efficient.
262276
"""
263277
assemble!(::AbstractAssembler, ::AbstractVector{<:Integer}, ::AbstractMatrix, ::AbstractVector)
264278

265279
@propagate_inbounds function assemble!(A::AbstractAssembler, dofs::AbstractVector{<:Integer}, Ke::AbstractMatrix, fe::Union{AbstractVector, Nothing} = nothing)
266-
return _assemble!(A, dofs, Ke, fe, false)
280+
size(Ke, 1) == size(Ke, 2) || throw(ArgumentError("Ke is rectangular, but only a single `dofs` vector is provided. Please call assemble!(A, rowdofs, coldofs, Ke, fe) instead."))
281+
return _assemble!(A, dofs, dofs, Ke, fe, false)
282+
end
283+
@propagate_inbounds function assemble!(A::AbstractAssembler, rowdofs::AbstractVector{<:Integer}, coldofs::AbstractVector{<:Integer}, Ke::AbstractMatrix, fe::Union{AbstractVector, Nothing} = nothing)
284+
return _assemble!(A, rowdofs, coldofs, Ke, fe, false)
267285
end
268286
@propagate_inbounds function assemble!(A::SymmetricCSCAssembler, dofs::AbstractVector{<:Integer}, Ke::AbstractMatrix, fe::Union{AbstractVector, Nothing} = nothing)
269-
return _assemble!(A, dofs, Ke, fe, true)
287+
return _assemble!(A, dofs, dofs, Ke, fe, true)
270288
end
271289

272290
"""
@@ -283,53 +301,62 @@ Sorts the dofs into a separate buffer and returns it together with a permutation
283301
return sorteddofs, permutation
284302
end
285303

286-
@propagate_inbounds function _assemble!(A::Union{AbstractCSCAssembler, AbstractCSRAssembler}, dofs::AbstractVector{<:Integer}, Ke::AbstractMatrix, fe::Union{AbstractVector, Nothing}, sym::Bool)
287-
ld = length(dofs)
288-
@boundscheck checkbounds(Ke, keys(dofs), keys(dofs))
304+
@propagate_inbounds function _assemble!(A::Union{AbstractCSCAssembler, AbstractCSRAssembler}, rowdofs::AbstractVector{<:Integer}, coldofs::AbstractVector{<:Integer}, Ke::AbstractMatrix, fe::Union{AbstractVector, Nothing}, sym::Bool)
305+
@boundscheck checkbounds(Ke, keys(rowdofs), keys(coldofs))
289306
if fe !== nothing
290-
@boundscheck checkbounds(fe, keys(dofs))
291-
@boundscheck checkbounds(A.f, dofs)
292-
@inbounds assemble!(A.f, dofs, fe)
307+
@boundscheck checkbounds(fe, keys(rowdofs))
308+
@boundscheck checkbounds(A.f, rowdofs)
309+
@inbounds assemble!(A.f, rowdofs, fe)
293310
end
294311

295312
K = matrix_handle(A)
296-
@boundscheck checkbounds(K, dofs, dofs)
313+
@boundscheck checkbounds(K, rowdofs, coldofs)
297314

298315
# We assume that the input dofs are not sorted, because the cells need the dofs in
299316
# a specific order, which might not be the sorted order. Hence we sort them.
300317
# Note that we are not allowed to mutate `rowdofs` or `coldofs` in the process.
301-
sorteddofs, permutation = _sortdofs_for_assembly!(A.permutation, A.sorteddofs, dofs)
318+
sortedcoldofs, colpermutation = _sortdofs_for_assembly!(A.colpermutation, A.sortedcoldofs, coldofs)
319+
sortedrowdofs, rowpermutation = if rowdofs !== coldofs
320+
_sortdofs_for_assembly!(A.rowpermutation, A.sortedrowdofs, rowdofs)
321+
else
322+
sortedcoldofs, colpermutation
323+
end
302324

303-
return _assemble_inner!(K, Ke, dofs, sorteddofs, permutation, sym)
325+
return _assemble_inner!(K, Ke, rowdofs, sortedrowdofs, rowpermutation, coldofs, sortedcoldofs, colpermutation, sym)
304326
end
305327

306-
@propagate_inbounds function _assemble_inner!(K::SparseMatrixCSC, Ke::AbstractMatrix, dofs::AbstractVector, sorteddofs::AbstractVector, permutation::AbstractVector, sym::Bool)
328+
@propagate_inbounds function _assemble_inner!(
329+
K::SparseMatrixCSC, Ke::AbstractMatrix,
330+
rowdofs::AbstractVector, sortedrowdofs::AbstractVector, rowpermutation::AbstractVector,
331+
coldofs::AbstractVector, sortedcoldofs::AbstractVector, colpermutation::AbstractVector,
332+
sym::Bool
333+
)
307334
current_col = 1
308335
Krows = rowvals(K)
309336
Kvals = nonzeros(K)
310-
ld = length(dofs)
311-
@inbounds for Kcol in sorteddofs
337+
ld = length(rowdofs)
338+
@inbounds for Kcol in sortedcoldofs
312339
maxlookups = sym ? current_col : ld
313-
Kecol = permutation[current_col]
340+
Kecol = colpermutation[current_col]
314341
ri = 1 # row index pointer for the local matrix
315342
Ri = 1 # row index pointer for the global matrix
316343
nzr = nzrange(K, Kcol)
317344
while Ri <= length(nzr) && ri <= maxlookups
318345
R = nzr[Ri]
319346
Krow = Krows[R]
320-
Kerow = permutation[ri]
347+
Kerow = rowpermutation[ri]
321348
val = Ke[Kerow, Kecol]
322-
if Krow == dofs[Kerow]
349+
if Krow == rowdofs[Kerow]
323350
# Match: add the value (if non-zero) and advance the pointers
324351
if !iszero(val)
325352
Kvals[R] += val
326353
end
327354
ri += 1
328355
Ri += 1
329-
elseif Krow < dofs[Kerow]
356+
elseif Krow < rowdofs[Kerow]
330357
# No match yet: advance the global matrix row pointer
331358
Ri += 1
332-
else # Krow > dofs[Kerow]
359+
else # Krow > rowdofs[Kerow]
333360
# No match: no entry exists in the global matrix for this row. This is
334361
# allowed as long as the value which would have been inserted is zero.
335362
iszero(val) || _missing_sparsity_pattern_error(Krow, Kcol)
@@ -339,8 +366,8 @@ end
339366
end
340367
# Make sure that remaining entries in this column of the local matrix are all zero
341368
for i in ri:maxlookups
342-
if !iszero(Ke[permutation[i], Kecol])
343-
_missing_sparsity_pattern_error(sorteddofs[i], Kcol)
369+
if !iszero(Ke[rowpermutation[i], Kecol])
370+
_missing_sparsity_pattern_error(sortedrowdofs[i], Kcol)
344371
end
345372
end
346373
current_col += 1

test/test_assemble.jl

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using Ferrite, SparseArrays
2+
import LinearAlgebra: Symmetric
3+
14
@testset "assemble" begin
25
dofs = [1, 3, 5, 7]
36
maxd = maximum(dofs)
@@ -43,14 +46,44 @@
4346
@test size(K) == (10, 10)
4447
@test length(f) == 10
4548

46-
# assemble with different row and col dofs
49+
# COOAssembler: assemble with different row and col dofs
4750
rdofs = [1, 4, 6]
4851
cdofs = [1, 7]
4952
a = Ferrite.COOAssembler()
5053
Ke = rand(length(rdofs), length(cdofs))
5154
assemble!(a, rdofs, cdofs, Ke)
5255
K, _ = finish_assemble(a)
53-
@test (K[rdofs, cdofs] .== Ke) |> all
56+
@test all(K[rdofs, cdofs] .== Ke)
57+
58+
# CSCAssembler: assemble with different row and col dofs
59+
I = [1, 1, 4, 4, 6, 6]
60+
J = [1, 3, 1, 3, 1, 3]
61+
V = zeros(length(I))
62+
K = sparse(I, J, V)
63+
f = zeros(6)
64+
assembler = start_assemble(K, f)
65+
rdofs = [1, 4, 6]
66+
cdofs = [1, 3]
67+
Ke = rand(length(rdofs), length(cdofs))
68+
fe = rand(length(rdofs))
69+
assemble!(assembler, rdofs, cdofs, Ke, fe)
70+
assemble!(assembler, rdofs, cdofs, Ke, fe)
71+
@test_throws ArgumentError assemble!(assembler, rdofs, Ke, fe) # Rectangular Ke requires separate rowdofs and coldofs
72+
@test all(K[rdofs, cdofs] .== 2Ke)
73+
@test all(f[rdofs] .== 2fe)
74+
75+
# CSCAssembler: Assemble rectangular part in square matrix
76+
K = SparseMatrixCSC(6, 6, [K.colptr..., 7, 7, 7], K.rowval, K.nzval)
77+
assembler = start_assemble(K, f)
78+
rdofs = [1, 4, 6]
79+
cdofs = [1, 3]
80+
Ke = rand(length(rdofs), length(cdofs))
81+
fe = rand(length(rdofs))
82+
assemble!(assembler, rdofs, cdofs, Ke, fe)
83+
assemble!(assembler, rdofs, cdofs, Ke, fe)
84+
@test_throws ArgumentError assemble!(assembler, rdofs, Ke, fe) # Rectangular Ke requires separate rowdofs and coldofs
85+
@test all(K[rdofs, cdofs] .== 2Ke)
86+
@test all(f[rdofs] .== 2fe)
5487

5588
# SparseMatrix assembler
5689
K = spzeros(10, 10)

test/test_assembler_extensions.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Ferrite
12
import SparseMatricesCSR: SparseMatrixCSR, sparsecsr
23
using SparseArrays, LinearAlgebra
34

@@ -84,6 +85,36 @@ using SparseArrays, LinearAlgebra
8485
@test K ≈ sparsecsr(I, J, V)
8586
@test f ≈ [4 / 3, 2.0, 1.0]
8687

88+
# CSRAssembler: assemble with different row and col dofs
89+
I = [1, 1, 4, 4, 6, 6]
90+
J = [1, 3, 1, 3, 1, 3]
91+
V = zeros(length(I))
92+
K = sparsecsr(I, J, V)
93+
f = zeros(6)
94+
assembler = start_assemble(K, f)
95+
rdofs = [1, 4, 6]
96+
cdofs = [1, 3]
97+
Ke = rand(length(rdofs), length(cdofs))
98+
fe = rand(length(rdofs))
99+
assemble!(assembler, rdofs, cdofs, Ke, fe)
100+
assemble!(assembler, rdofs, cdofs, Ke, fe)
101+
@test_throws ArgumentError assemble!(assembler, rdofs, Ke, fe) # Rectangular Ke requires separate rowdofs and coldofs
102+
@test all(K[rdofs, cdofs] .== 2Ke)
103+
@test all(f[rdofs] .== 2fe)
104+
105+
# CSRAssembler: Assemble rectangular part in square matrix
106+
K = SparseMatrixCSR{1}(6, 6, K.rowptr, K.colval, K.nzval)
107+
assembler = start_assemble(K, f)
108+
rdofs = [1, 4, 6]
109+
cdofs = [1, 3]
110+
Ke = rand(length(rdofs), length(cdofs))
111+
fe = rand(length(rdofs))
112+
assemble!(assembler, rdofs, cdofs, Ke, fe)
113+
assemble!(assembler, rdofs, cdofs, Ke, fe)
114+
@test_throws ArgumentError assemble!(assembler, rdofs, Ke, fe) # Rectangular Ke requires separate rowdofs and coldofs
115+
@test all(K[rdofs, cdofs] .== 2Ke)
116+
@test all(f[rdofs] .== 2fe)
117+
87118
# Check if coupling works
88119
grid = generate_grid(Quadrilateral, (2, 2))
89120
ip = Lagrange{RefQuadrilateral, 1}()

0 commit comments

Comments
 (0)