diff --git a/CHANGELOG.md b/CHANGELOG.md index 48ed00ecfa..cf262a4530 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ ## Bug fixes +- `IDAKLUSolver` now records `Solution.closest_event_idx` after a SUNDIALS root return, so `BaseSolver.get_termination_reason` can short-circuit instead of re-walking every TERMINATION event's symbolic expression on the Python side. On a 1000-cycle SPM with `output_variables` set, cumulative allocations dropped 25% (~445 MB) and wall time 16%; the eliminated path was hot in long event-terminated cycling experiments. ([#5502](https://github.com/pybamm-team/PyBaMM/pull/5502)) - Fixed `Serialise.serialise_experiment` / `deserialise_experiment` dropping every constructor argument other than per-step `value` / `duration` / `terminations` / `temperature`. The top-level `period`, `temperature`, and `termination` arguments to `pybamm.Experiment` and the per-step `period`, `tags`, `description`, `start_time`, `direction`, and `skip_ok` arguments to `BaseStep` are now written by `to_config()` and parsed back by `from_config()`, so JSON round-tripped experiments preserve user intent. The `Resistance` step type was also missing from the deserialiser's step-type map and now round-trips correctly. - Fixed `Serialise.load_custom_model` reconstructing `Event.event_type` as the bare enum name string (e.g. `"TERMINATION"`) instead of the corresponding `EventType` member. The custom JSON encoder writes Enum values as their `.name`, but the loader was passing the string straight through to `pybamm.Event.__init__`, so models round-tripped through `to_json` / `from_json` carried string event types rather than `EventType` enum values. ([#5498](https://github.com/pybamm-team/PyBaMM/pull/5498)) - Fixed `Serialise.serialise_solver` silently dropping `root_method` (and any other nested `BaseSolver` `__init__` argument). After construction, `root_method` is a `BaseSolver` instance rather than the original string, so it failed `json.dumps` and was omitted from the config — making the deserialised solver fall back to the default. `to_config()` / `from_config()` now recurse into nested solver values, preserving `root_method` and its tolerances across the round-trip. ([#5497](https://github.com/pybamm-team/PyBaMM/pull/5497)) diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index b931445495..a8e1d9a2b6 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -15,6 +15,14 @@ _UNSET = object() + +def _flatten_inputs(inputs_dict): + """Flatten ``{name: value}`` into a 1-D float array in dict-key order.""" + if not inputs_dict: + return np.zeros(0) + return np.concatenate([np.asarray(v).reshape(-1) for v in inputs_dict.values()]) + + # Mirrors SUNDIALS ``IDA_ROOT_RETURN`` in ``sundials/include/ida/ida.h``. # Returned by ``IDASolve`` (and surfaced via ``Solution.flag``) when the # integrator has located one or more root function zeros. @@ -546,6 +554,10 @@ def to_idaklu(fn): self._setup[name] = fn self._setup[f"{name}_pkl"] = pkl + # Used in _post_process_solution to set closest_event_idx without + # going via BaseSolver.get_termination_reason's per-event re-walk. + self._setup["rootfn_casadi"] = fns["rootfn"] + for key in self.output_variables: fn, pkl = to_idaklu(fns[f"var:{key}"]) self._setup["var_idaklu_fcns"].append(fn) @@ -603,6 +615,7 @@ def __getstate__(self): "mass_action", "sensfn", "rootfn", + "rootfn_casadi", "var_idaklu_fcns", "dvar_dy_idaklu_fcns", "dvar_dp_idaklu_fcns", @@ -631,6 +644,10 @@ def __setstate__(self, d): ]: self._setup[key] = idaklu.generate_function(self._setup[f"{key}_pkl"]) + self._setup["rootfn_casadi"] = casadi.Function.deserialize( + self._setup["rootfn_pkl"] + ) + for key in ["var_idaklu_fcns", "dvar_dy_idaklu_fcns", "dvar_dp_idaklu_fcns"]: self._setup[key] = [ idaklu.generate_function(f) for f in self._setup[f"{key}_pkl"] @@ -660,12 +677,7 @@ def _integrate( # stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters) if inputs_list and inputs_list[0]: - inputs = np.vstack( - [ - np.hstack([np.array(x).reshape(-1) for x in inputs_dict.values()]) - for inputs_dict in inputs_list - ] - ) + inputs = np.vstack([_flatten_inputs(d) for d in inputs_list]) else: inputs = np.array([[]] * len(inputs_list)) @@ -806,6 +818,18 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict, t_ev options=solution_options, ) + # Set closest_event_idx so BaseSolver.get_termination_reason doesn't + # re-walk every event's symbolic expression on the Python side. + if sol.flag == _IDA_ROOT_RETURN and self._setup["num_of_events"] > 0: + event_values = np.asarray( + self._setup["rootfn_casadi"]( + float(sol.t[-1]), + np.asarray(y_event).reshape(-1), + _flatten_inputs(inputs_dict), + ) + ).reshape(-1) + newsol.closest_event_idx = int(np.nanargmin(np.abs(event_values))) + newsol.integration_time = integration_time if not save_outputs_only: return newsol diff --git a/tests/memory/test_memory_leaks.py b/tests/memory/test_memory_leaks.py index e4eb3694e9..546edde1bb 100644 --- a/tests/memory/test_memory_leaks.py +++ b/tests/memory/test_memory_leaks.py @@ -274,3 +274,56 @@ def test_gitt_memory(self): peak_mb = peak / 1024 / 1024 assert peak_mb < 6, f"Peak memory {peak_mb:.1f} MB for GITT is excessive." + + def test_idaklu_event_termination_no_python_event_reeval(self): + """ + IDAKLUSolver with output_variables on a long event-terminated experiment + must not re-walk every TERMINATION event's symbolic expression on the + Python side after each step. That path goes through + StateVector._base_evaluate (per-leaf numpy fancy-indexing) and dominated + per-step allocation churn before the closest_event_idx fix. + + Allocations are transient (immediately freed each step), so peak-RSS / + tracemalloc snapshots can't see them. Count direct calls to + _base_evaluate instead — deterministic across runs. + """ + original = pybamm.StateVector._base_evaluate + call_count = 0 + + def counting(self, *args, **kwargs): + nonlocal call_count + call_count += 1 + return original(self, *args, **kwargs) + + pybamm.StateVector._base_evaluate = counting + try: + experiment = pybamm.Experiment( + [ + ( + "Discharge at 1C until 3.0 V", + "Charge at 1C until 4.2 V", + "Hold at 4.2 V until C/50", + ) + ] + * 5, + period=300, + ) + sim = pybamm.Simulation( + pybamm.lithium_ion.SPM(), + experiment=experiment, + solver=pybamm.IDAKLUSolver(output_variables=["Voltage [V]"]), + ) + sim.solve() + finally: + pybamm.StateVector._base_evaluate = original + + # Solve-time setup paths legitimately call _base_evaluate (initial + # conditions, event setup) — those scale with the model, not the + # cycle count. Per-step re-evaluation pushes the count past 1000 on + # a 5-cycle SPM run; the fix keeps it under 200. + assert call_count < 300, ( + f"StateVector._base_evaluate was called {call_count} times during " + f"a 5-cycle solve. IDAKLUSolver should set closest_event_idx so " + f"BaseSolver.get_termination_reason short-circuits instead of " + f"re-walking event expressions on every step." + ) diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index ad7fb5bdc2..f06b3c9352 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -1096,6 +1096,83 @@ def test_with_output_variables_and_event_termination(self): sol3 = sim3.solve(np.linspace(0, 3600, 2)) assert sol3.termination == "event: Minimum voltage [V]" + def test_closest_event_idx_set_after_root_return(self): + # IDAKLU must populate Solution.closest_event_idx after a root return so + # BaseSolver.get_termination_reason short-circuits instead of re-walking + # every TERMINATION event's symbolic expression on the Python side. That + # slow path generated tens of thousands of small numpy allocations per + # long event-terminated cycling run. + cycle = ( + "Discharge at 1C until 3.0 V", + "Charge at 1C until 4.2 V", + "Hold at 4.2 V until C/50", + ) + sim = pybamm.Simulation( + pybamm.lithium_ion.SPM(), + experiment=pybamm.Experiment([cycle] * 2, period=300), + solver=pybamm.IDAKLUSolver(output_variables=["Voltage [V]"]), + ) + sim.solve() + + event_steps = [ + step + for cycle_sol in sim.solution.cycles + for step in cycle_sol.steps + if step.termination.startswith("event:") + ] + assert event_steps, "expected at least one event-terminated step" + # The index must also resolve to the same event name the slow path in + # BaseSolver.get_termination_reason would have picked. + for step in event_steps: + assert step.closest_event_idx is not None, ( + f"event-terminated step {step.termination!r} has " + f"closest_event_idx=None — BaseSolver will fall back to " + f"per-step Python event re-evaluation" + ) + terminate_events = [ + e + for e in step.all_models[-1].events + if e.event_type == pybamm.EventType.TERMINATION + ] + picked = terminate_events[step.closest_event_idx].name + assert step.termination == f"event: {picked}", ( + f"closest_event_idx={step.closest_event_idx} resolves to " + f"{picked!r}, but step.termination is {step.termination!r}" + ) + + def test_pickle_roundtrip_preserves_closest_event_idx(self): + # rootfn_casadi is dropped from the pickle and rebuilt from rootfn_pkl + # in __setstate__; confirm a round-tripped solver still records + # closest_event_idx after a root-return termination. + import pickle + + solver = pybamm.IDAKLUSolver(output_variables=["Voltage [V]"]) + sim = pybamm.Simulation( + pybamm.lithium_ion.SPM(), + experiment=pybamm.Experiment( + [("Discharge at 1C until 3.0 V", "Charge at 1C until 4.2 V")] + ), + solver=solver, + ) + sim.solve() + + roundtripped = pickle.loads(pickle.dumps(solver)) + sim2 = pybamm.Simulation( + pybamm.lithium_ion.SPM(), + experiment=pybamm.Experiment( + [("Discharge at 1C until 3.0 V", "Charge at 1C until 4.2 V")] + ), + solver=roundtripped, + ) + sim2.solve() + + for step in sim2.solution.cycles[0].steps: + if step.termination.startswith("event:"): + assert step.closest_event_idx is not None, ( + "round-tripped IDAKLUSolver must still set " + "closest_event_idx after a root return" + ) + def test_simulation_period(self): model = pybamm.lithium_ion.DFN() parameter_values = pybamm.ParameterValues("Chen2020")