Skip to content

Commit 69f8150

Browse files
committed
avoid using threadid in landau example, instead use OhMyThreads + ChunkSplitters
1 parent 236eb50 commit 69f8150

1 file changed

Lines changed: 61 additions & 57 deletions

File tree

docs/src/literate-gallery/landau.jl

Lines changed: 61 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using Ferrite
2626
using Optim, LineSearches
2727
using SparseArrays
2828
using Tensors
29+
using OhMyThreads, ChunkSplitters
2930

3031
# ## Energy terms
3132
# ### 4th order Landau free energy
@@ -47,7 +48,7 @@ struct ModelParams{V, T}
4748
end
4849

4950
# ### ThreadCache
50-
# This holds the values that each thread will use during the assembly.
51+
# This holds the values that each task will use during the assembly.
5152
struct ThreadCache{CV, T, DIM, F <: Function, GC <: GradientConfig, HC <: HessianConfig}
5253
cvP::CV
5354
element_indices::Vector{Int}
@@ -72,16 +73,17 @@ function ThreadCache(dpc::Int, nodespercell, cvP::CellValues, modelparams, elpot
7273
end
7374

7475
# ## The Model
75-
# everything is combined into a model.
76+
# Everything is combined into a model. The caches are pre-allocated (one per task)
77+
# and indexed by chunk index during assembly.
7678
mutable struct LandauModel{T, DH <: DofHandler, CH <: ConstraintHandler, TC <: ThreadCache}
7779
dofs::Vector{T}
7880
dofhandler::DH
7981
boundaryconds::CH
8082
threadindices::Vector{Vector{Int}}
81-
threadcaches::Vector{TC}
83+
caches::Vector{TC}
8284
end
8385

84-
function LandauModel(α, G, gridsize, left::Vec{DIM, T}, right::Vec{DIM, T}, elpotential) where {DIM, T}
86+
function LandauModel(α, G, gridsize, left::Vec{DIM, T}, right::Vec{DIM, T}, elpotential, ntasks) where {DIM, T}
8587
grid = generate_grid(Tetrahedron, gridsize, left, right)
8688
threadindices = Ferrite.create_coloring(grid)
8789

@@ -106,7 +108,7 @@ function LandauModel(α, G, gridsize, left::Vec{DIM, T}, right::Vec{DIM, T}, elp
106108

107109
dpc = ndofs_per_cell(dofhandler)
108110
cpc = length(grid.cells[1].nodes)
109-
caches = [ThreadCache(dpc, cpc, copy(cvP), ModelParams(α, G), elpotential) for t in 1:Threads.maxthreadid()]
111+
caches = [ThreadCache(dpc, cpc, copy(cvP), ModelParams(α, G), elpotential) for _ in 1:ntasks]
110112
return LandauModel(dofvector, dofhandler, boundaryconds, threadindices, caches)
111113
end
112114

@@ -119,75 +121,77 @@ function save_landau(path, model, dofs = model.dofs)
119121
end
120122

121123
# ## Assembly
122-
# This macro defines most of the assembly step, since the structure is the same for
123-
# the energy, gradient and Hessian calculations.
124-
macro assemble!(innerbody)
125-
return esc(
126-
quote
127-
dofhandler = model.dofhandler
128-
for indices in model.threadindices
129-
Threads.@threads for i in indices
130-
cache = model.threadcaches[Threads.threadid()]
131-
eldofs = cache.element_dofs
132-
nodeids = dofhandler.grid.cells[i].nodes
133-
for j in 1:length(cache.element_coords)
134-
cache.element_coords[j] = dofhandler.grid.nodes[nodeids[j]].x
135-
end
136-
reinit!(cache.cvP, cache.element_coords)
137-
138-
celldofs!(cache.element_indices, dofhandler, i)
139-
for j in 1:length(cache.element_dofs)
140-
eldofs[j] = dofvector[cache.element_indices[j]]
141-
end
142-
$innerbody
143-
end
144-
end
145-
end
146-
)
124+
# This helper sets up the cell data in the cache for a given cell index,
125+
# and returns the element dof values.
126+
function setup_cell!(cache, dofhandler, dofvector, cellidx)
127+
nodeids = dofhandler.grid.cells[cellidx].nodes
128+
for j in 1:length(cache.element_coords)
129+
cache.element_coords[j] = dofhandler.grid.nodes[nodeids[j]].x
130+
end
131+
reinit!(cache.cvP, cache.element_coords)
132+
celldofs!(cache.element_indices, dofhandler, cellidx)
133+
eldofs = cache.element_dofs
134+
for j in 1:length(eldofs)
135+
eldofs[j] = dofvector[cache.element_indices[j]]
136+
end
137+
return eldofs
147138
end
148139

149-
# This calculates the total energy calculation of the grid
140+
# This calculates the total energy of the grid.
150141
function F(dofvector::Vector{T}, model) where {T}
151-
outs = fill(zero(T), Threads.maxthreadid())
152-
@assemble! begin
153-
outs[Threads.threadid()] += cache.element_potential(eldofs)
142+
out = zero(T)
143+
for indices in model.threadindices
144+
partial = OhMyThreads.@tasks for (ichunk, range) in enumerate(chunks(indices; n = length(model.caches)))
145+
@set reducer = +
146+
cache = model.caches[ichunk]
147+
local_energy = zero(T)
148+
for i in range
149+
eldofs = setup_cell!(cache, model.dofhandler, dofvector, i)
150+
local_energy += cache.element_potential(eldofs)
151+
end
152+
local_energy
153+
end
154+
out += partial
154155
end
155-
return sum(outs)
156+
return out
156157
end
157158

158-
# The gradient calculation for each dof
159+
# The gradient calculation for each dof.
160+
# The grid coloring ensures no two cells within a color share dofs,
161+
# so assembly is safe without locks.
159162
function ∇F!(∇f::Vector{T}, dofvector::Vector{T}, model::LandauModel{T}) where {T}
160163
fill!(∇f, zero(T))
161-
@assemble! begin
162-
ForwardDiff.gradient!(cache.element_gradient, cache.element_potential, eldofs, cache.gradconf)
163-
@inbounds assemble!(∇f, cache.element_indices, cache.element_gradient)
164+
for indices in model.threadindices
165+
OhMyThreads.@tasks for (ichunk, range) in enumerate(chunks(indices; n = length(model.caches)))
166+
cache = model.caches[ichunk]
167+
for i in range
168+
eldofs = setup_cell!(cache, model.dofhandler, dofvector, i)
169+
ForwardDiff.gradient!(cache.element_gradient, cache.element_potential, eldofs, cache.gradconf)
170+
@inbounds assemble!(∇f, cache.element_indices, cache.element_gradient)
171+
end
172+
end
164173
end
165174
return
166175
end
167176

168177
# The Hessian calculation for the whole grid
169178
function ∇²F!(∇²f::SparseMatrixCSC, dofvector::Vector{T}, model::LandauModel{T}) where {T}
170-
assemblers = [start_assemble(∇²f) for t in 1:Threads.maxthreadid()]
171-
@assemble! begin
172-
ForwardDiff.hessian!(cache.element_hessian, cache.element_potential, eldofs, cache.hessconf)
173-
@inbounds assemble!(assemblers[Threads.threadid()], cache.element_indices, cache.element_hessian)
179+
dh = model.dofhandler
180+
ntasks = length(model.caches)
181+
assemblers = [start_assemble(∇²f; fillzero = (i == 1)) for i in 1:ntasks]
182+
for indices in model.threadindices
183+
OhMyThreads.@tasks for (ichunk, range) in enumerate(chunks(indices; n = ntasks))
184+
cache = model.caches[ichunk]
185+
for i in range
186+
eldofs = setup_cell!(cache, dh, dofvector, i)
187+
ForwardDiff.hessian!(cache.element_hessian, cache.element_potential, eldofs, cache.hessconf)
188+
@inbounds assemble!(assemblers[ichunk], cache.element_indices, cache.element_hessian)
189+
end
190+
end
174191
end
175192
return
176193
end
177194

178-
# We can also calculate all things in one go!
179-
function calcall(∇²f::SparseMatrixCSC, ∇f::Vector{T}, dofvector::Vector{T}, model::LandauModel{T}) where {T}
180-
outs = fill(zero(T), Threads.maxthreadid())
181-
fill!(∇f, zero(T))
182-
assemblers = [start_assemble(∇²f, ∇f) for t in 1:Threads.maxthreadid()]
183-
@assemble! begin
184-
outs[Threads.threadid()] += cache.element_potential(eldofs)
185-
ForwardDiff.hessian!(cache.element_hessian, cache.element_potential, eldofs, cache.hessconf)
186-
ForwardDiff.gradient!(cache.element_gradient, cache.element_potential, eldofs, cache.gradconf)
187-
@inbounds assemble!(assemblers[Threads.threadid()], cache.element_indices, cache.element_gradient, cache.element_hessian)
188-
end
189-
return sum(outs)
190-
end
191195

192196
# ## Minimization
193197
# Now everything can be combined to minimize the energy, and find the equilibrium
@@ -255,7 +259,7 @@ G = V2T(1.0e2, 0.0, 1.0e2)
255259
α = Vec{3}((-1.0, 1.0, 1.0))
256260
left = Vec{3}((-75.0, -25.0, -2.0))
257261
right = Vec{3}((75.0, 25.0, 2.0))
258-
model = LandauModel(α, G, (50, 50, 2), left, right, element_potential)
262+
model = LandauModel(α, G, (50, 50, 2), left, right, element_potential, Threads.nthreads())
259263

260264
save_landau("landauorig", model)
261265
@time minimize!(model)

0 commit comments

Comments
 (0)