|
1 | 1 | module FerriteSparseMatrixCSR |
2 | 2 |
|
3 | 3 | using Ferrite, SparseArrays, SparseMatricesCSR |
4 | | -import Ferrite: AbstractSparsityPattern, CSRAssembler |
| 4 | +import Ferrite: AbstractSparsityPattern, CSRAssembler, FastSparsityPattern, getnrows, getncols |
5 | 5 | import Base: @propagate_inbounds |
6 | 6 |
|
7 | 7 | # Could be generalized if https://github.com/JuliaSparse/SparseArrays.jl/pull/546 is merged |
@@ -112,4 +112,51 @@ function _allocate_matrix(::Type{SparseMatrixCSR{1, Tv, Ti}}, sp::AbstractSparsi |
112 | 112 | return S |
113 | 113 | end |
114 | 114 |
|
| 115 | +## ================= ## |
| 116 | +# FastSparsityPattern # |
| 117 | +## ================= ## |
| 118 | + |
"""
    _allocate_matrix(::Type{SparseMatrixCSR{1, Tv, Ti}}, sp::FastSparsityPattern{Ti}, sym::Bool)

Build a one-based `SparseMatrixCSR` with zero-initialized values of type `Tv` from the
pattern `sp`.

The column indices of `sp` are sorted first (in place) unless `sp.is_colidx_sorted` is
already set. `sp.rowptr` and `sp.colidx` are handed to the matrix constructor as-is;
this function does not copy them.

Throws an `ArgumentError` when `sym` is `true`.
"""
function _allocate_matrix(::Type{SparseMatrixCSR{1, Tv, Ti}}, sp::FastSparsityPattern{Ti}, sym::Bool) where {Tv, Ti}
    if sym
        throw(ArgumentError("FastSparsityPattern does not support symmetric matrices yet"))
    end
    # The CSR matrix requires the column indices of each row to be sorted.
    if !sp.is_colidx_sorted
        sort_rows_threaded!(sp)
    end
    nstored = length(sp.colidx)
    vals = zeros(Tv, nstored)
    nrows = getnrows(sp)
    ncols = getncols(sp)
    return SparseMatrixCSR{1}(nrows, ncols, sp.rowptr, sp.colidx, vals)
end
| 125 | + |
"""
    sort_rows!(sp)

Sort the column indices of every row of `sp` in place (serially), flag the pattern as
sorted by setting `sp.is_colidx_sorted = true`, and return `sp`.
"""
function sort_rows!(sp)
    allrows = 1:getnrows(sp)
    sort_rows!(sp, allrows)
    sp.is_colidx_sorted = true
    return sp
end
| 131 | + |
"""
    sort_rows!(sp::FastSparsityPattern, rowrange::UnitRange)

Sort, in place, the column indices stored for each row in `rowrange`, and return `sp`.

For a row `r`, the entries sorted are `sp.colidx[sp.rowptr[r]:(sp.rowptr[r + 1] - 1)]`.
This method does not touch `sp.is_colidx_sorted`. Indexing into `sp.rowptr` is done
under `@inbounds`, so `rowrange` must only contain valid row numbers.
"""
function sort_rows!(sp::FastSparsityPattern, rowrange::UnitRange)
    rowptr = sp.rowptr
    colidx = sp.colidx
    @inbounds for r in rowrange
        lo = rowptr[r]
        hi = rowptr[r + 1] - 1
        # Rows holding fewer than two entries are already sorted.
        lo < hi || continue
        sort!(view(colidx, lo:hi))
    end
    return sp
end
| 142 | + |
"""
    sort_rows_threaded!(sp::FastSparsityPattern, ntasks = ...)

Sort the column indices of every row of `sp` in place, then set
`sp.is_colidx_sorted = true` and return `sp`.

The rows are split into at most `ntasks` contiguous chunks, and each chunk is sorted by
its own task started with `Threads.@spawn`. By default `ntasks` is about one task per
1000 rows, capped at 100 tasks per thread (for load balancing) and never less than one.

Throws an `ArgumentError` if `ntasks < 1`.
"""
function sort_rows_threaded!(
        sp::FastSparsityPattern, # Default ΔN ≥ 1000 and `ntasks ≥ 1`
        ntasks = max(min(Threads.nthreads() * 100, getnrows(sp) ÷ 1000), 1)
    ) # Otherwise, 100 per thread for load balancing
    ntasks ≥ 1 || throw(ArgumentError("ntasks must be at least 1, got $ntasks"))
    nrows = getnrows(sp)
    # Round the chunk size up so that `ntasks` chunks always cover all rows. Rounding
    # down would leave the trailing `nrows % ntasks` rows unsorted (and sort nothing at
    # all when `ntasks > nrows`) while the pattern is still flagged as sorted.
    ΔN = cld(nrows, ntasks)
    Base.Experimental.@sync begin
        for taskid in 1:ntasks
            first_idx = 1 + ΔN * (taskid - 1)
            # With a rounded-up chunk size the last few tasks may have no rows left.
            first_idx > nrows && break
            last_idx = min(first_idx + ΔN - 1, nrows)
            Threads.@spawn sort_rows!(sp, first_idx:last_idx)
        end
    end
    sp.is_colidx_sorted = true
    return sp
end
| 161 | + |
115 | 162 | end |
0 commit comments