Skip to content

Commit c98643a

Browse files
committed
avoid using threadid in landau example, instead use OhMyThreads + ChunkSplitters
1 parent 236eb50 commit c98643a

3 files changed

Lines changed: 73 additions & 68 deletions

File tree

docs/Manifest.toml

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
julia_version = "1.12.1"
3+
julia_version = "1.12.5"
44
manifest_format = "2.0"
5-
project_hash = "765feb17a5ba23a96b16cd2bae96223ad0618954"
5+
project_hash = "6d0dc5950af997fe06d71214d77ad65f439e8f72"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "27cecae79e5cc9935255f90c53bb831cc3c870d7"
@@ -231,9 +231,9 @@ uuid = "5217a498-cd5d-4ec6-b8c2-9b85a09b6e3e"
231231
version = "1.1.0"
232232

233233
[[deps.ChunkSplitters]]
234-
git-tree-sha1 = "63a3903063d035260f0f6eab00f517471c5dc784"
234+
git-tree-sha1 = "1c52c8e2673edc030191177ff1aee42d25149acb"
235235
uuid = "ae650224-84b6-46f8-82ea-d812ca08434e"
236-
version = "3.1.2"
236+
version = "3.2.0"
237237

238238
[[deps.CloseOpenIntervals]]
239239
deps = ["Static", "StaticArrayInterface"]
@@ -521,7 +521,7 @@ version = "1.4.1"
521521
[[deps.Downloads]]
522522
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
523523
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
524-
version = "1.6.0"
524+
version = "1.7.0"
525525

526526
[[deps.EnumX]]
527527
git-tree-sha1 = "bddad79635af6aec424f53ed8aad5d7555dc6f00"
@@ -1066,7 +1066,7 @@ version = "0.6.4"
10661066
[[deps.LibCURL_jll]]
10671067
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll", "Zlib_jll", "nghttp2_jll"]
10681068
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
1069-
version = "8.11.1+1"
1069+
version = "8.15.0+0"
10701070

10711071
[[deps.LibGit2]]
10721072
deps = ["LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
@@ -1358,7 +1358,7 @@ version = "0.3.7"
13581358

13591359
[[deps.MozillaCACerts_jll]]
13601360
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
1361-
version = "2025.5.20"
1361+
version = "2025.11.4"
13621362

13631363
[[deps.MuladdMacro]]
13641364
git-tree-sha1 = "cac9cc5499c25554cba55cd3c30543cff5ca4fab"
@@ -1440,7 +1440,7 @@ version = "1.6.0"
14401440
[[deps.OpenSSL_jll]]
14411441
deps = ["Artifacts", "Libdl"]
14421442
uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
1443-
version = "3.5.1+0"
1443+
version = "3.5.4+0"
14441444

14451445
[[deps.OpenSpecFun_jll]]
14461446
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"]
@@ -1533,7 +1533,7 @@ version = "0.44.2+0"
15331533
[[deps.Pkg]]
15341534
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"]
15351535
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1536-
version = "1.12.0"
1536+
version = "1.12.1"
15371537
weakdeps = ["REPL"]
15381538

15391539
[deps.Pkg.extensions]
@@ -2419,9 +2419,9 @@ uuid = "1317d2d5-d96f-522e-a858-c73665f53c3e"
24192419
version = "2022.0.0+1"
24202420

24212421
[[deps.p7zip_jll]]
2422-
deps = ["Artifacts", "Libdl"]
2422+
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
24232423
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
2424-
version = "17.5.0+2"
2424+
version = "17.7.0+0"
24252425

24262426
[[deps.x264_jll]]
24272427
deps = ["Artifacts", "JLLWrappers", "Libdl"]

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Bibliography = "f1be7e48-bf82-45af-a471-ae754a193061"
33
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
44
Changelog = "5217a498-cd5d-4ec6-b8c2-9b85a09b6e3e"
5+
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
56
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
67
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
78
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"

docs/src/literate-gallery/landau.jl

Lines changed: 61 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using Ferrite
2626
using Optim, LineSearches
2727
using SparseArrays
2828
using Tensors
29+
using OhMyThreads, ChunkSplitters
2930

3031
# ## Energy terms
3132
# ### 4th order Landau free energy
@@ -47,7 +48,7 @@ struct ModelParams{V, T}
4748
end
4849

4950
# ### ThreadCache
50-
# This holds the values that each thread will use during the assembly.
51+
# This holds the values that each task will use during the assembly.
5152
struct ThreadCache{CV, T, DIM, F <: Function, GC <: GradientConfig, HC <: HessianConfig}
5253
cvP::CV
5354
element_indices::Vector{Int}
@@ -72,16 +73,17 @@ function ThreadCache(dpc::Int, nodespercell, cvP::CellValues, modelparams, elpot
7273
end
7374

7475
# ## The Model
75-
# everything is combined into a model.
76+
# Everything is combined into a model. The caches are pre-allocated (one per task)
77+
# and indexed by chunk index during assembly.
7678
mutable struct LandauModel{T, DH <: DofHandler, CH <: ConstraintHandler, TC <: ThreadCache}
7779
dofs::Vector{T}
7880
dofhandler::DH
7981
boundaryconds::CH
8082
threadindices::Vector{Vector{Int}}
81-
threadcaches::Vector{TC}
83+
caches::Vector{TC}
8284
end
8385

84-
function LandauModel(α, G, gridsize, left::Vec{DIM, T}, right::Vec{DIM, T}, elpotential) where {DIM, T}
86+
function LandauModel(α, G, gridsize, left::Vec{DIM, T}, right::Vec{DIM, T}, elpotential, ntasks) where {DIM, T}
8587
grid = generate_grid(Tetrahedron, gridsize, left, right)
8688
threadindices = Ferrite.create_coloring(grid)
8789

@@ -106,7 +108,7 @@ function LandauModel(α, G, gridsize, left::Vec{DIM, T}, right::Vec{DIM, T}, elp
106108

107109
dpc = ndofs_per_cell(dofhandler)
108110
cpc = length(grid.cells[1].nodes)
109-
caches = [ThreadCache(dpc, cpc, copy(cvP), ModelParams(α, G), elpotential) for t in 1:Threads.maxthreadid()]
111+
caches = [ThreadCache(dpc, cpc, copy(cvP), ModelParams(α, G), elpotential) for _ in 1:ntasks]
110112
return LandauModel(dofvector, dofhandler, boundaryconds, threadindices, caches)
111113
end
112114

@@ -119,75 +121,77 @@ function save_landau(path, model, dofs = model.dofs)
119121
end
120122

121123
# ## Assembly
122-
# This macro defines most of the assembly step, since the structure is the same for
123-
# the energy, gradient and Hessian calculations.
124-
macro assemble!(innerbody)
125-
return esc(
126-
quote
127-
dofhandler = model.dofhandler
128-
for indices in model.threadindices
129-
Threads.@threads for i in indices
130-
cache = model.threadcaches[Threads.threadid()]
131-
eldofs = cache.element_dofs
132-
nodeids = dofhandler.grid.cells[i].nodes
133-
for j in 1:length(cache.element_coords)
134-
cache.element_coords[j] = dofhandler.grid.nodes[nodeids[j]].x
135-
end
136-
reinit!(cache.cvP, cache.element_coords)
137-
138-
celldofs!(cache.element_indices, dofhandler, i)
139-
for j in 1:length(cache.element_dofs)
140-
eldofs[j] = dofvector[cache.element_indices[j]]
141-
end
142-
$innerbody
143-
end
144-
end
145-
end
146-
)
124+
# This helper sets up the cell data in the cache for a given cell index,
125+
# and returns the element dof values.
126+
function setup_cell!(cache, dofhandler, dofvector, cellidx)
127+
nodeids = dofhandler.grid.cells[cellidx].nodes
128+
for j in 1:length(cache.element_coords)
129+
cache.element_coords[j] = dofhandler.grid.nodes[nodeids[j]].x
130+
end
131+
reinit!(cache.cvP, cache.element_coords)
132+
celldofs!(cache.element_indices, dofhandler, cellidx)
133+
eldofs = cache.element_dofs
134+
for j in 1:length(eldofs)
135+
eldofs[j] = dofvector[cache.element_indices[j]]
136+
end
137+
return eldofs
147138
end
148139

149-
# This calculates the total energy calculation of the grid
140+
# This calculates the total energy of the grid.
150141
function F(dofvector::Vector{T}, model) where {T}
151-
outs = fill(zero(T), Threads.maxthreadid())
152-
@assemble! begin
153-
outs[Threads.threadid()] += cache.element_potential(eldofs)
142+
out = zero(T)
143+
for indices in model.threadindices
144+
partial = OhMyThreads.@tasks for (ichunk, range) in enumerate(chunks(indices; n = length(model.caches)))
145+
@set reducer = +
146+
cache = model.caches[ichunk]
147+
local_energy = zero(T)
148+
for i in range
149+
eldofs = setup_cell!(cache, model.dofhandler, dofvector, i)
150+
local_energy += cache.element_potential(eldofs)
151+
end
152+
local_energy
153+
end
154+
out += partial
154155
end
155-
return sum(outs)
156+
return out
156157
end
157158

158-
# The gradient calculation for each dof
159+
# The gradient calculation for each dof.
160+
# The grid coloring ensures no two cells within a color share dofs,
161+
# so assembly is safe without locks.
159162
function ∇F!(∇f::Vector{T}, dofvector::Vector{T}, model::LandauModel{T}) where {T}
160163
fill!(∇f, zero(T))
161-
@assemble! begin
162-
ForwardDiff.gradient!(cache.element_gradient, cache.element_potential, eldofs, cache.gradconf)
163-
@inbounds assemble!(∇f, cache.element_indices, cache.element_gradient)
164+
for indices in model.threadindices
165+
OhMyThreads.@tasks for (ichunk, range) in enumerate(chunks(indices; n = length(model.caches)))
166+
cache = model.caches[ichunk]
167+
for i in range
168+
eldofs = setup_cell!(cache, model.dofhandler, dofvector, i)
169+
ForwardDiff.gradient!(cache.element_gradient, cache.element_potential, eldofs, cache.gradconf)
170+
@inbounds assemble!(∇f, cache.element_indices, cache.element_gradient)
171+
end
172+
end
164173
end
165174
return
166175
end
167176

168177
# The Hessian calculation for the whole grid. Each task gets its own assembler.
169178
function ∇²F!(∇²f::SparseMatrixCSC, dofvector::Vector{T}, model::LandauModel{T}) where {T}
170-
assemblers = [start_assemble(∇²f) for t in 1:Threads.maxthreadid()]
171-
@assemble! begin
172-
ForwardDiff.hessian!(cache.element_hessian, cache.element_potential, eldofs, cache.hessconf)
173-
@inbounds assemble!(assemblers[Threads.threadid()], cache.element_indices, cache.element_hessian)
179+
dh = model.dofhandler
180+
ntasks = length(model.caches)
181+
assemblers = [start_assemble(∇²f; fillzero = (i == 1)) for i in 1:ntasks]
182+
for indices in model.threadindices
183+
OhMyThreads.@tasks for (ichunk, range) in enumerate(chunks(indices; n = ntasks))
184+
cache = model.caches[ichunk]
185+
for i in range
186+
eldofs = setup_cell!(cache, dh, dofvector, i)
187+
ForwardDiff.hessian!(cache.element_hessian, cache.element_potential, eldofs, cache.hessconf)
188+
@inbounds assemble!(assemblers[ichunk], cache.element_indices, cache.element_hessian)
189+
end
190+
end
174191
end
175192
return
176193
end
177194

178-
# We can also calculate all things in one go!
179-
function calcall(∇²f::SparseMatrixCSC, ∇f::Vector{T}, dofvector::Vector{T}, model::LandauModel{T}) where {T}
180-
outs = fill(zero(T), Threads.maxthreadid())
181-
fill!(∇f, zero(T))
182-
assemblers = [start_assemble(∇²f, ∇f) for t in 1:Threads.maxthreadid()]
183-
@assemble! begin
184-
outs[Threads.threadid()] += cache.element_potential(eldofs)
185-
ForwardDiff.hessian!(cache.element_hessian, cache.element_potential, eldofs, cache.hessconf)
186-
ForwardDiff.gradient!(cache.element_gradient, cache.element_potential, eldofs, cache.gradconf)
187-
@inbounds assemble!(assemblers[Threads.threadid()], cache.element_indices, cache.element_gradient, cache.element_hessian)
188-
end
189-
return sum(outs)
190-
end
191195

192196
# ## Minimization
193197
# Now everything can be combined to minimize the energy, and find the equilibrium
@@ -255,7 +259,7 @@ G = V2T(1.0e2, 0.0, 1.0e2)
255259
α = Vec{3}((-1.0, 1.0, 1.0))
256260
left = Vec{3}((-75.0, -25.0, -2.0))
257261
right = Vec{3}((75.0, 25.0, 2.0))
258-
model = LandauModel(α, G, (50, 50, 2), left, right, element_potential)
262+
model = LandauModel(α, G, (50, 50, 2), left, right, element_potential, Threads.nthreads())
259263

260264
save_landau("landauorig", model)
261265
@time minimize!(model)

0 commit comments

Comments
 (0)