Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

## Bug fixes

- `IDAKLUSolver` now records `Solution.closest_event_idx` after a SUNDIALS root return, so `BaseSolver.get_termination_reason` can short-circuit instead of re-walking every TERMINATION event's symbolic expression on the Python side. On a 1000-cycle SPM with `output_variables` set, cumulative allocations dropped 25% (~445 MB) and wall time 16%; the eliminated path was hot in long event-terminated cycling experiments. ([#5502](https://github.com/pybamm-team/PyBaMM/pull/5502))
- Fixed `Serialise.serialise_experiment` / `deserialise_experiment` dropping every constructor argument other than per-step `value` / `duration` / `terminations` / `temperature`. The top-level `period`, `temperature`, and `termination` arguments to `pybamm.Experiment` and the per-step `period`, `tags`, `description`, `start_time`, `direction`, and `skip_ok` arguments to `BaseStep` are now written by `to_config()` and parsed back by `from_config()`, so JSON round-tripped experiments preserve user intent. The `Resistance` step type was also missing from the deserialiser's step-type map and now round-trips correctly.
- Fixed `Serialise.load_custom_model` reconstructing `Event.event_type` as the bare enum name string (e.g. `"TERMINATION"`) instead of the corresponding `EventType` member. The custom JSON encoder writes Enum values as their `.name`, but the loader was passing the string straight through to `pybamm.Event.__init__`, so models round-tripped through `to_json` / `from_json` carried string event types rather than `EventType` enum values. ([#5498](https://github.com/pybamm-team/PyBaMM/pull/5498))
- Fixed `Serialise.serialise_solver` silently dropping `root_method` (and any other nested `BaseSolver` `__init__` argument). After construction, `root_method` is a `BaseSolver` instance rather than the original string, so it failed `json.dumps` and was omitted from the config — making the deserialised solver fall back to the default. `to_config()` / `from_config()` now recurse into nested solver values, preserving `root_method` and its tolerances across the round-trip. ([#5497](https://github.com/pybamm-team/PyBaMM/pull/5497))
Expand Down
36 changes: 30 additions & 6 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@

_UNSET = object()


def _flatten_inputs(inputs_dict):
"""Flatten ``{name: value}`` into a 1-D float array in dict-key order."""
if not inputs_dict:
return np.zeros(0)
return np.concatenate([np.asarray(v).reshape(-1) for v in inputs_dict.values()])


# Mirrors SUNDIALS ``IDA_ROOT_RETURN`` in ``sundials/include/ida/ida.h``.
# Returned by ``IDASolve`` (and surfaced via ``Solution.flag``) when the
# integrator has located one or more root function zeros.
Expand Down Expand Up @@ -546,6 +554,10 @@ def to_idaklu(fn):
self._setup[name] = fn
self._setup[f"{name}_pkl"] = pkl

# Used in _post_process_solution to set closest_event_idx without
# going via BaseSolver.get_termination_reason's per-event re-walk.
self._setup["rootfn_casadi"] = fns["rootfn"]

for key in self.output_variables:
fn, pkl = to_idaklu(fns[f"var:{key}"])
self._setup["var_idaklu_fcns"].append(fn)
Expand Down Expand Up @@ -603,6 +615,7 @@ def __getstate__(self):
"mass_action",
"sensfn",
"rootfn",
"rootfn_casadi",
"var_idaklu_fcns",
"dvar_dy_idaklu_fcns",
"dvar_dp_idaklu_fcns",
Expand Down Expand Up @@ -631,6 +644,10 @@ def __setstate__(self, d):
]:
self._setup[key] = idaklu.generate_function(self._setup[f"{key}_pkl"])

self._setup["rootfn_casadi"] = casadi.Function.deserialize(
self._setup["rootfn_pkl"]
)

for key in ["var_idaklu_fcns", "dvar_dy_idaklu_fcns", "dvar_dp_idaklu_fcns"]:
self._setup[key] = [
idaklu.generate_function(f) for f in self._setup[f"{key}_pkl"]
Expand Down Expand Up @@ -660,12 +677,7 @@ def _integrate(

# stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters)
if inputs_list and inputs_list[0]:
inputs = np.vstack(
[
np.hstack([np.array(x).reshape(-1) for x in inputs_dict.values()])
for inputs_dict in inputs_list
]
)
inputs = np.vstack([_flatten_inputs(d) for d in inputs_list])
else:
inputs = np.array([[]] * len(inputs_list))

Expand Down Expand Up @@ -806,6 +818,18 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict, t_ev
options=solution_options,
)

# Set closest_event_idx so BaseSolver.get_termination_reason doesn't
# re-walk every event's symbolic expression on the Python side.
if sol.flag == _IDA_ROOT_RETURN and self._setup["num_of_events"] > 0:
event_values = np.asarray(
self._setup["rootfn_casadi"](
float(sol.t[-1]),
np.asarray(y_event).reshape(-1),
_flatten_inputs(inputs_dict),
)
).reshape(-1)
newsol.closest_event_idx = int(np.nanargmin(np.abs(event_values)))

newsol.integration_time = integration_time
if not save_outputs_only:
return newsol
Expand Down
53 changes: 53 additions & 0 deletions tests/memory/test_memory_leaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,56 @@ def test_gitt_memory(self):

peak_mb = peak / 1024 / 1024
assert peak_mb < 6, f"Peak memory {peak_mb:.1f} MB for GITT is excessive."

def test_idaklu_event_termination_no_python_event_reeval(self):
"""
IDAKLUSolver with output_variables on a long event-terminated experiment
must not re-walk every TERMINATION event's symbolic expression on the
Python side after each step. That path goes through
StateVector._base_evaluate (per-leaf numpy fancy-indexing) and dominated
per-step allocation churn before the closest_event_idx fix.

Allocations are transient (immediately freed each step), so peak-RSS /
tracemalloc snapshots can't see them. Count direct calls to
_base_evaluate instead — deterministic across runs.
"""
original = pybamm.StateVector._base_evaluate
call_count = 0

def counting(self, *args, **kwargs):
nonlocal call_count
call_count += 1
return original(self, *args, **kwargs)

pybamm.StateVector._base_evaluate = counting
try:
experiment = pybamm.Experiment(
[
(
"Discharge at 1C until 3.0 V",
"Charge at 1C until 4.2 V",
"Hold at 4.2 V until C/50",
)
]
* 5,
period=300,
)
sim = pybamm.Simulation(
pybamm.lithium_ion.SPM(),
experiment=experiment,
solver=pybamm.IDAKLUSolver(output_variables=["Voltage [V]"]),
)
sim.solve()
finally:
pybamm.StateVector._base_evaluate = original

# Solve-time setup paths legitimately call _base_evaluate (initial
# conditions, event setup) — those scale with the model, not the
# cycle count. Per-step re-evaluation pushes the count past 1000 on
# a 5-cycle SPM run; the fix keeps it under 200.
assert call_count < 300, (
f"StateVector._base_evaluate was called {call_count} times during "
f"a 5-cycle solve. IDAKLUSolver should set closest_event_idx so "
f"BaseSolver.get_termination_reason short-circuits instead of "
f"re-walking event expressions on every step."
)
77 changes: 77 additions & 0 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,83 @@ def test_with_output_variables_and_event_termination(self):
sol3 = sim3.solve(np.linspace(0, 3600, 2))
assert sol3.termination == "event: Minimum voltage [V]"

def test_closest_event_idx_set_after_root_return(self):
# IDAKLU must populate Solution.closest_event_idx after a root return so
# BaseSolver.get_termination_reason short-circuits instead of re-walking
# every TERMINATION event's symbolic expression on the Python side. That
# slow path generated tens of thousands of small numpy allocations per
# long event-terminated cycling run.
cycle = (
"Discharge at 1C until 3.0 V",
"Charge at 1C until 4.2 V",
"Hold at 4.2 V until C/50",
)
sim = pybamm.Simulation(
pybamm.lithium_ion.SPM(),
experiment=pybamm.Experiment([cycle] * 2, period=300),
solver=pybamm.IDAKLUSolver(output_variables=["Voltage [V]"]),
)
sim.solve()

event_steps = [
step
for cycle_sol in sim.solution.cycles
for step in cycle_sol.steps
if step.termination.startswith("event:")
]
assert event_steps, "expected at least one event-terminated step"
# The index must also resolve to the same event name the slow path in
# BaseSolver.get_termination_reason would have picked.
for step in event_steps:
assert step.closest_event_idx is not None, (
f"event-terminated step {step.termination!r} has "
f"closest_event_idx=None — BaseSolver will fall back to "
f"per-step Python event re-evaluation"
)
terminate_events = [
e
for e in step.all_models[-1].events
if e.event_type == pybamm.EventType.TERMINATION
]
picked = terminate_events[step.closest_event_idx].name
assert step.termination == f"event: {picked}", (
f"closest_event_idx={step.closest_event_idx} resolves to "
f"{picked!r}, but step.termination is {step.termination!r}"
)

def test_pickle_roundtrip_preserves_closest_event_idx(self):
# rootfn_casadi is dropped from the pickle and rebuilt from rootfn_pkl
# in __setstate__; confirm a round-tripped solver still records
# closest_event_idx after a root-return termination.
import pickle

solver = pybamm.IDAKLUSolver(output_variables=["Voltage [V]"])
sim = pybamm.Simulation(
pybamm.lithium_ion.SPM(),
experiment=pybamm.Experiment(
[("Discharge at 1C until 3.0 V", "Charge at 1C until 4.2 V")]
),
solver=solver,
)
sim.solve()

roundtripped = pickle.loads(pickle.dumps(solver))
sim2 = pybamm.Simulation(
pybamm.lithium_ion.SPM(),
experiment=pybamm.Experiment(
[("Discharge at 1C until 3.0 V", "Charge at 1C until 4.2 V")]
),
solver=roundtripped,
)
sim2.solve()

for step in sim2.solution.cycles[0].steps:
if step.termination.startswith("event:"):
assert step.closest_event_idx is not None, (
"round-tripped IDAKLUSolver must still set "
"closest_event_idx after a root return"
)

def test_simulation_period(self):
model = pybamm.lithium_ion.DFN()
parameter_values = pybamm.ParameterValues("Chen2020")
Expand Down
Loading