Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

## Bug fixes

- `IDAKLUSolver` now records `Solution.closest_event_idx` after a SUNDIALS root return, so `BaseSolver.get_termination_reason` can short-circuit instead of re-walking every TERMINATION event's symbolic expression on the Python side. On a 1000-cycle SPM with `output_variables` set, cumulative allocations dropped by 25% (~445 MB) and wall time by 16%; the eliminated path was a hotspot in long event-terminated cycling experiments.
- Fixed `Serialise.serialise_solver` silently dropping `root_method` (and any other nested `BaseSolver` `__init__` argument). After construction, `root_method` is a `BaseSolver` instance rather than the original string, so it failed `json.dumps` and was omitted from the config — making the deserialised solver fall back to the default. `to_config()` / `from_config()` now recurse into nested solver values, preserving `root_method` and its tolerances across the round-trip. ([#5497](https://github.com/pybamm-team/PyBaMM/pull/5497))
- Fixed `Serialise.serialise_experiment` silently dropping per-step `temperature` overrides. The JSON-config round-trip via `Experiment.to_config()` / `Experiment.from_config()` now preserves `temperature` for current, voltage, power, c-rate, and rest steps. ([#5496](https://github.com/pybamm-team/PyBaMM/pull/5496))
- Fixed `Serialise._to_json_safe` coercing Python `bool` values to `0`/`1` ints because `bool` is a subclass of `int`. `IDAKLUSolver.to_config()` now emits its bool options (`compile`, `print_stats`, `silence_sundials_errors`, etc.) as JSON `true`/`false` so they round-trip through strict-bool deserialisers. ([#5495](https://github.com/pybamm-team/PyBaMM/pull/5495))
Expand Down
36 changes: 30 additions & 6 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@

_UNSET = object()


def _flatten_inputs(inputs_dict):
"""Flatten ``{name: value}`` into a 1-D float array in dict-key order."""
if not inputs_dict:
return np.zeros(0)
return np.concatenate([np.asarray(v).reshape(-1) for v in inputs_dict.values()])


# Mirrors SUNDIALS ``IDA_ROOT_RETURN`` in ``sundials/include/ida/ida.h``.
# Returned by ``IDASolve`` (and surfaced via ``Solution.flag``) when the
# integrator has located one or more root function zeros.
Expand Down Expand Up @@ -546,6 +554,10 @@ def to_idaklu(fn):
self._setup[name] = fn
self._setup[f"{name}_pkl"] = pkl

# Used in _post_process_solution to set closest_event_idx without
# going via BaseSolver.get_termination_reason's per-event re-walk.
self._setup["rootfn_casadi"] = fns["rootfn"]

for key in self.output_variables:
fn, pkl = to_idaklu(fns[f"var:{key}"])
self._setup["var_idaklu_fcns"].append(fn)
Expand Down Expand Up @@ -603,6 +615,7 @@ def __getstate__(self):
"mass_action",
"sensfn",
"rootfn",
"rootfn_casadi",
"var_idaklu_fcns",
"dvar_dy_idaklu_fcns",
"dvar_dp_idaklu_fcns",
Expand Down Expand Up @@ -631,6 +644,10 @@ def __setstate__(self, d):
]:
self._setup[key] = idaklu.generate_function(self._setup[f"{key}_pkl"])

self._setup["rootfn_casadi"] = casadi.Function.deserialize(
self._setup["rootfn_pkl"]
)

for key in ["var_idaklu_fcns", "dvar_dy_idaklu_fcns", "dvar_dp_idaklu_fcns"]:
self._setup[key] = [
idaklu.generate_function(f) for f in self._setup[f"{key}_pkl"]
Expand Down Expand Up @@ -660,12 +677,7 @@ def _integrate(

# stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters)
if inputs_list and inputs_list[0]:
inputs = np.vstack(
[
np.hstack([np.array(x).reshape(-1) for x in inputs_dict.values()])
for inputs_dict in inputs_list
]
)
inputs = np.vstack([_flatten_inputs(d) for d in inputs_list])
else:
inputs = np.array([[]] * len(inputs_list))

Expand Down Expand Up @@ -806,6 +818,18 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict, t_ev
options=solution_options,
)

# Set closest_event_idx so BaseSolver.get_termination_reason doesn't
# re-walk every event's symbolic expression on the Python side.
if sol.flag == _IDA_ROOT_RETURN and self._setup["num_of_events"] > 0:
event_values = np.asarray(
self._setup["rootfn_casadi"](
float(sol.t[-1]),
np.asarray(y_event).reshape(-1),
_flatten_inputs(inputs_dict),
)
).reshape(-1)
newsol.closest_event_idx = int(np.nanargmin(np.abs(event_values)))

newsol.integration_time = integration_time
if not save_outputs_only:
return newsol
Expand Down
53 changes: 53 additions & 0 deletions tests/memory/test_memory_leaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,56 @@ def test_gitt_memory(self):

peak_mb = peak / 1024 / 1024
assert peak_mb < 6, f"Peak memory {peak_mb:.1f} MB for GITT is excessive."

def test_idaklu_event_termination_no_python_event_reeval(self):
    """
    Guard against IDAKLUSolver re-walking every TERMINATION event's
    symbolic expression on the Python side after each step of a long
    event-terminated experiment, when output_variables is set.

    The slow path funnels through StateVector._base_evaluate (per-leaf
    numpy fancy-indexing) and dominated per-step allocation churn before
    the closest_event_idx fix. Those allocations are transient — freed
    immediately each step — so peak-RSS / tracemalloc snapshots cannot
    observe them. Counting direct calls to _base_evaluate is
    deterministic across runs instead.
    """
    unpatched = pybamm.StateVector._base_evaluate
    call_count = 0

    def counted(self, *args, **kwargs):
        nonlocal call_count
        call_count += 1
        return unpatched(self, *args, **kwargs)

    pybamm.StateVector._base_evaluate = counted
    try:
        cycle = (
            "Discharge at 1C until 3.0 V",
            "Charge at 1C until 4.2 V",
            "Hold at 4.2 V until C/50",
        )
        simulation = pybamm.Simulation(
            pybamm.lithium_ion.SPM(),
            experiment=pybamm.Experiment([cycle] * 5, period=300),
            solver=pybamm.IDAKLUSolver(output_variables=["Voltage [V]"]),
        )
        simulation.solve()
    finally:
        pybamm.StateVector._base_evaluate = unpatched

    # Solve-time setup paths (initial conditions, event setup) call
    # _base_evaluate legitimately, but that cost scales with the model,
    # not the cycle count. Per-step re-evaluation pushes the count past
    # 1000 on a 5-cycle SPM run; the fix keeps it under 200.
    assert call_count < 300, (
        f"StateVector._base_evaluate was called {call_count} times during "
        f"a 5-cycle solve. IDAKLUSolver should set closest_event_idx so "
        f"BaseSolver.get_termination_reason short-circuits instead of "
        f"re-walking event expressions on every step."
    )
77 changes: 77 additions & 0 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,83 @@ def test_with_output_variables_and_event_termination(self):
sol3 = sim3.solve(np.linspace(0, 3600, 2))
assert sol3.termination == "event: Minimum voltage [V]"

def test_closest_event_idx_set_after_root_return(self):
    # After a SUNDIALS root return, IDAKLU must populate
    # Solution.closest_event_idx so BaseSolver.get_termination_reason can
    # short-circuit rather than re-walking every TERMINATION event's
    # symbolic expression in Python — a path that produced tens of
    # thousands of small numpy allocations per long cycling run.
    protocol = (
        "Discharge at 1C until 3.0 V",
        "Charge at 1C until 4.2 V",
        "Hold at 4.2 V until C/50",
    )
    sim = pybamm.Simulation(
        pybamm.lithium_ion.SPM(),
        experiment=pybamm.Experiment([protocol] * 2, period=300),
        solver=pybamm.IDAKLUSolver(output_variables=["Voltage [V]"]),
    )
    sim.solve()

    event_steps = []
    for cycle_sol in sim.solution.cycles:
        event_steps.extend(
            step
            for step in cycle_sol.steps
            if step.termination.startswith("event:")
        )
    assert event_steps, "expected at least one event-terminated step"

    # The recorded index must also resolve to the same event name the
    # slow path in BaseSolver.get_termination_reason would have chosen.
    for step in event_steps:
        assert step.closest_event_idx is not None, (
            f"event-terminated step {step.termination!r} has "
            f"closest_event_idx=None — BaseSolver will fall back to "
            f"per-step Python event re-evaluation"
        )
        termination_events = [
            event
            for event in step.all_models[-1].events
            if event.event_type == pybamm.EventType.TERMINATION
        ]
        picked = termination_events[step.closest_event_idx].name
        assert step.termination == f"event: {picked}", (
            f"closest_event_idx={step.closest_event_idx} resolves to "
            f"{picked!r}, but step.termination is {step.termination!r}"
        )

def test_pickle_roundtrip_preserves_closest_event_idx(self):
    # rootfn_casadi is excluded from the pickle and reconstructed from
    # rootfn_pkl in __setstate__; a round-tripped solver must therefore
    # still record closest_event_idx after a root-return termination.
    import pickle

    def build_simulation(solver):
        # Fresh model + experiment per simulation, matching separate
        # construction for the original and round-tripped solvers.
        return pybamm.Simulation(
            pybamm.lithium_ion.SPM(),
            experiment=pybamm.Experiment(
                [("Discharge at 1C until 3.0 V", "Charge at 1C until 4.2 V")]
            ),
            solver=solver,
        )

    solver = pybamm.IDAKLUSolver(output_variables=["Voltage [V]"])
    build_simulation(solver).solve()

    restored = pickle.loads(pickle.dumps(solver))
    sim2 = build_simulation(restored)
    sim2.solve()

    for step in sim2.solution.cycles[0].steps:
        if step.termination.startswith("event:"):
            assert step.closest_event_idx is not None, (
                "round-tripped IDAKLUSolver must still set "
                "closest_event_idx after a root return"
            )

def test_simulation_period(self):
model = pybamm.lithium_ion.DFN()
parameter_values = pybamm.ParameterValues("Chen2020")
Expand Down
Loading