diff --git a/Project.toml b/Project.toml index 2ee5fb9..0305388 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ authors = ["Patrick Altmeyer "] name = "TaijaParallel" uuid = "bf1c2c22-5e42-4e78-8b6b-92e6c673eeb0" -version = "1.1.3" +version = "1.2.0" [compat] Aqua = "0.8" diff --git a/ext/MPIExt/MPIExt.jl b/ext/MPIExt/MPIExt.jl index 2668d79..f66721f 100644 --- a/ext/MPIExt/MPIExt.jl +++ b/ext/MPIExt/MPIExt.jl @@ -7,6 +7,7 @@ using MPI using ProgressMeter using TaijaBase using TaijaParallel +using TaijaParallel: load_with_retry "The `MPIParallelizer` type is used to parallelize the evaluation of a function using `MPI.jl`." struct MPIParallelizer <: TaijaParallel.AbstractParallelizer diff --git a/ext/MPIExt/evaluate.jl b/ext/MPIExt/evaluate.jl index c718883..e9a51d9 100644 --- a/ext/MPIExt/evaluate.jl +++ b/ext/MPIExt/evaluate.jl @@ -92,7 +92,7 @@ function TaijaBase.parallelize( if parallelizer.rank == 0 outputs = [] for i = 1:length(chunks) - output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls")) + output = load_with_retry(joinpath(storage_path, "output_$i.jls")) push!(outputs, output) end # Collect output from all processes in rank 0: diff --git a/ext/MPIExt/generate_counterfactual.jl b/ext/MPIExt/generate_counterfactual.jl index fea7e34..031abd2 100644 --- a/ext/MPIExt/generate_counterfactual.jl +++ b/ext/MPIExt/generate_counterfactual.jl @@ -94,7 +94,7 @@ function TaijaBase.parallelize( if parallelizer.rank == 0 outputs = [] for i = 1:length(chunks) - output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls")) + output = load_with_retry(joinpath(storage_path, "output_$i.jls")) push!(outputs, output) end # Collect output from all processes in rank 0: diff --git a/src/utils.jl b/src/utils.jl index 5fa8ad0..dbb1613 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,5 @@ +using Serialization + """ chunk_obs(obs::AbstractVector, n_each::Integer, n_groups::Integer) @@ -38,3 +40,23 @@ function 
"""
    load_with_retry(filepath; max_attempts=5, delay=1.0)

Deserialize the object stored at `filepath` with `Serialization.deserialize`, retrying
with exponential backoff when the read fails with an `EOFError`.

Only `EOFError` is treated as transient (presumably the file is still being written by
another process when it is read -- confirm against the callers). Any other exception is
propagated immediately.

# Keywords
- `max_attempts=5`: total number of read attempts, including the first one.
- `delay=1.0`: base wait in seconds. The wait after failed attempt `k` is
  `delay * 2^(k - 1)`, i.e. `delay`, `2delay`, `4delay`, ...

# Returns
The deserialized object.

# Throws
- `EOFError` if the final attempt still hits the end of the file.
- Whatever `Serialization.deserialize` throws for any non-`EOFError` failure (no retry).
- `ErrorException` if `max_attempts < 1`, since no attempt is made at all.
"""
function load_with_retry(filepath; max_attempts=5, delay=1.0)
    for attempt in 1:max_attempts
        try
            return Serialization.deserialize(filepath)
        catch e
            if e isa EOFError && attempt < max_attempts
                # Double the wait after every failed attempt. `exp2` keeps the factor in
                # floating point, so large attempt counts cannot overflow an `Int`.
                sleep(delay * exp2(attempt - 1))
                continue
            end
            # Not an `EOFError`, or out of attempts: propagate the current exception
            # together with its original backtrace.
            rethrow()
        end
    end
    # Only reachable when the loop body never ran, i.e. `max_attempts < 1`.
    return error("Failed to load $filepath after $max_attempts attempts")
end