diff --git a/Project.toml b/Project.toml
index 2ee5fb9..0305388 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
authors = ["Patrick Altmeyer
"]
name = "TaijaParallel"
uuid = "bf1c2c22-5e42-4e78-8b6b-92e6c673eeb0"
-version = "1.1.3"
+version = "1.2.0"
[compat]
Aqua = "0.8"
diff --git a/ext/MPIExt/MPIExt.jl b/ext/MPIExt/MPIExt.jl
index 2668d79..f66721f 100644
--- a/ext/MPIExt/MPIExt.jl
+++ b/ext/MPIExt/MPIExt.jl
@@ -7,6 +7,7 @@ using MPI
using ProgressMeter
using TaijaBase
using TaijaParallel
+using TaijaParallel: load_with_retry
"The `MPIParallelizer` type is used to parallelize the evaluation of a function using `MPI.jl`."
struct MPIParallelizer <: TaijaParallel.AbstractParallelizer
diff --git a/ext/MPIExt/evaluate.jl b/ext/MPIExt/evaluate.jl
index c718883..e9a51d9 100644
--- a/ext/MPIExt/evaluate.jl
+++ b/ext/MPIExt/evaluate.jl
@@ -92,7 +92,7 @@ function TaijaBase.parallelize(
if parallelizer.rank == 0
outputs = []
for i = 1:length(chunks)
- output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
+ output = load_with_retry(joinpath(storage_path, "output_$i.jls"))
push!(outputs, output)
end
# Collect output from all processes in rank 0:
diff --git a/ext/MPIExt/generate_counterfactual.jl b/ext/MPIExt/generate_counterfactual.jl
index fea7e34..031abd2 100644
--- a/ext/MPIExt/generate_counterfactual.jl
+++ b/ext/MPIExt/generate_counterfactual.jl
@@ -94,7 +94,7 @@ function TaijaBase.parallelize(
if parallelizer.rank == 0
outputs = []
for i = 1:length(chunks)
- output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
+ output = load_with_retry(joinpath(storage_path, "output_$i.jls"))
push!(outputs, output)
end
# Collect output from all processes in rank 0:
diff --git a/src/utils.jl b/src/utils.jl
index 5fa8ad0..dbb1613 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,3 +1,5 @@
+using Serialization
+
"""
chunk_obs(obs::AbstractVector, n_each::Integer, n_groups::Integer)
@@ -38,3 +40,23 @@ function split_obs(obs::AbstractVector, n::Integer)
N_counts = split_count(N, n)
return split_by_counts(obs, N_counts)
end
+
+"""
+    load_with_retry(filepath; max_attempts=5, delay=1.0)
+
+Deserialize `filepath` via `Serialization.deserialize`, retrying on `EOFError` up to
+`max_attempts` times; after failed attempt `k` it sleeps `delay * 2^(k - 1)` seconds.
+Any other exception, and the `EOFError` of the final attempt, is rethrown unchanged.
+"""
+function load_with_retry(filepath; max_attempts=5, delay=1.0)
+    for attempt in 1:max_attempts
+        try
+            return Serialization.deserialize(filepath)
+        catch err
+            # Retry only on EOFError while attempts remain; otherwise propagate as-is.
+            (err isa EOFError && attempt < max_attempts) || rethrow()
+            sleep(delay * 2.0^(attempt - 1))  # exponential: delay, 2*delay, 4*delay, ...
+        end
+    end
+    error("Failed to load $filepath after $max_attempts attempts")  # only if max_attempts < 1
+end