diff --git a/ext/MPIExt/MPIExt.jl b/ext/MPIExt/MPIExt.jl index 2668d79..8914a61 100644 --- a/ext/MPIExt/MPIExt.jl +++ b/ext/MPIExt/MPIExt.jl @@ -15,6 +15,7 @@ struct MPIParallelizer <: TaijaParallel.AbstractParallelizer n_proc::Int n_each::Union{Nothing,Int} threaded::Bool + active_comm::MPI.Comm end """ @@ -26,9 +27,11 @@ function TaijaParallel.MPIParallelizer( comm::MPI.Comm; n_each::Union{Nothing,Int} = nothing, threaded::Bool = false, + active_comm::Union{Nothing,MPI.Comm} = comm ) - rank = MPI.Comm_rank(comm) # Rank of this process in the world 🌍 - n_proc = MPI.Comm_size(comm) # Number of processes in the world 🌍 + rank = MPI.Comm_rank(comm) # Rank of this process in the world 🌍 + n_proc = MPI.Comm_size(comm) # Number of processes in the world 🌍 + active_comm = isnothing(active_comm) ? comm : active_comm # Active communication channel (if specified) if rank == 0 @info "Using `MPI.jl` for multi-processing." @@ -42,7 +45,7 @@ function TaijaParallel.MPIParallelizer( end end - return MPIParallelizer(comm, rank, n_proc, n_each, threaded) + return MPIParallelizer(comm, rank, n_proc, n_each, threaded, active_comm) end include("generate_counterfactual.jl") diff --git a/ext/MPIExt/evaluate.jl b/ext/MPIExt/evaluate.jl index 1bbd296..fb9051f 100644 --- a/ext/MPIExt/evaluate.jl +++ b/ext/MPIExt/evaluate.jl @@ -73,7 +73,7 @@ function TaijaBase.parallelize( kwargs..., ) end - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) # Collect output from all processe in rank 0: collected_output = MPI.gather(output, parallelizer.comm) @@ -81,11 +81,11 @@ function TaijaBase.parallelize( output = vcat(collected_output...) Serialization.serialize(joinpath(storage_path, "output_$i.jls"), output) end - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) end # Collect all chunks in rank 0: - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) # Load output from rank 0: if parallelizer.rank == 0 @@ -101,8 +101,8 @@ function TaijaBase.parallelize( end # Broadcast output to all processes: - final_output = MPI.bcast(output, parallelizer.comm; root = 0) - MPI.Barrier(parallelizer.comm) + final_output = MPI.bcast(output, parallelizer.active_comm; root = 0) + MPI.Barrier(parallelizer.active_comm) return final_output end diff --git a/ext/MPIExt/generate_counterfactual.jl b/ext/MPIExt/generate_counterfactual.jl index 95fc51b..b1eb015 100644 --- a/ext/MPIExt/generate_counterfactual.jl +++ b/ext/MPIExt/generate_counterfactual.jl @@ -73,7 +73,7 @@ function TaijaBase.parallelize( kwargs..., ) end - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) # Collect output from all processe in rank 0: collected_output = MPI.gather(output, parallelizer.comm) @@ -81,11 +81,11 @@ function TaijaBase.parallelize( output = vcat(collected_output...) Serialization.serialize(joinpath(storage_path, "output_$i.jls"), output) end - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) end # Collect all chunks in rank 0: - MPI.Barrier(parallelizer.comm) + MPI.Barrier(parallelizer.active_comm) # Load output from rank 0: if parallelizer.rank == 0 @@ -101,8 +101,8 @@ function TaijaBase.parallelize( end # Broadcast output to all processes: - final_output = MPI.bcast(output, parallelizer.comm; root = 0) - MPI.Barrier(parallelizer.comm) + final_output = MPI.bcast(output, parallelizer.active_comm; root = 0) + MPI.Barrier(parallelizer.active_comm) return final_output end diff --git a/src/extensions/MPIExt.jl b/src/extensions/MPIExt.jl index 0973f47..0858eca 100644 --- a/src/extensions/MPIExt.jl +++ b/src/extensions/MPIExt.jl @@ -5,3 +5,12 @@ Exposes the `MPIParallelizer` function from the `MPIExt` extension. """ function MPIParallelizer end export MPIParallelizer + +global _active_comm = nothing + +function set_active_comm(comm) + global _active_comm = comm + return _active_comm +end + +get_active_comm() = _active_comm \ No newline at end of file