Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,11 +885,26 @@ def _save_t1(subsaveat, save_state):
return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats


def _validate_solver(solver: Any) -> AbstractSolver:
"""Raise a clear error if `solver` was passed as a class, not an instance."""
if isinstance(solver, AbstractSolver):
return solver
if isinstance(solver, type) and issubclass(solver, AbstractSolver):
raise ValueError(
"It looks like you forgot to instantiate your solver, e.g. by passing "
"`diffrax.Euler` instead of `diffrax.Euler()`."
)
raise ValueError(
"Argument `solver` must be an instance of (some subclass of) "
"`diffrax.AbstractSolver`, but its type is not recognised."
)


@eqx.filter_jit
@eqxi.doc_remove_args("discrete_terminating_event")
def diffeqsolve(
terms: PyTree[AbstractTerm],
solver: AbstractSolver,
solver: AbstractSolver | type[AbstractSolver],
t0: RealScalarLike,
t1: RealScalarLike,
dt0: RealScalarLike | None,
Expand Down Expand Up @@ -1014,6 +1029,8 @@ def diffeqsolve(
# Initial set-up
#

validated_solver: AbstractSolver = _validate_solver(solver)

# Backward compatibility
if discrete_terminating_event is not None:
warnings.warn(
Expand Down Expand Up @@ -1100,22 +1117,22 @@ def _promote(yi):
del timelikes

# Backward compatibility
if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)):
if isinstance(validated_solver, (EulerHeun, ItoMilstein, StratonovichMilstein)):
try:
_assert_term_compatible(
t0,
y0,
args,
terms,
(ODETerm, AbstractTerm),
solver.term_compatible_contr_kwargs,
validated_solver.term_compatible_contr_kwargs,
)
except Exception as _:
pass
else:
warnings.warn(
"Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to "
f"{solver.__class__.__name__} is deprecated in favour of "
f"{validated_solver.__class__.__name__} is deprecated in favour of "
"`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that "
"the same terms can now be passed used for both general "
"and SDE-specific solvers!",
Expand All @@ -1129,20 +1146,22 @@ def _promote(yi):
y0,
args,
terms,
solver.term_structure,
solver.term_compatible_contr_kwargs,
validated_solver.term_structure,
validated_solver.term_compatible_contr_kwargs,
)

if is_sde(terms):
if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
if not isinstance(
validated_solver, (AbstractItoSolver, AbstractStratonovichSolver)
):
warnings.warn(
f"`{type(solver).__name__}` is not marked as converging to either the "
"Itô or the Stratonovich solution.",
f"`{type(validated_solver).__name__}` is not marked as converging to "
"either the Itô or the Stratonovich solution.",
stacklevel=2,
)
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
# Specific check to not work even if using HalfSolver(Euler())
if isinstance(solver, Euler):
if isinstance(validated_solver, Euler):
raise ValueError(
"An SDE should not be solved with adaptive step sizes with Euler's "
"method, as it may not converge to the correct solution."
Expand Down Expand Up @@ -1175,26 +1194,28 @@ def _wrap(term):
is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm),
)

if isinstance(solver, AbstractImplicitSolver):
if isinstance(validated_solver, AbstractImplicitSolver):

def _get_tols(x):
outs = []
for attr in ("rtol", "atol", "norm"):
if (
getattr(cast(AbstractImplicitSolver, solver).root_finder, attr)
getattr(
cast(AbstractImplicitSolver, validated_solver).root_finder, attr
)
is use_stepsize_tol
):
outs.append(getattr(x, attr))
return tuple(outs)

if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
solver = eqx.tree_at(
validated_solver = eqx.tree_at(
lambda s: _get_tols(s.root_finder),
solver,
validated_solver,
_get_tols(stepsize_controller),
)
else:
if len(_get_tols(solver.root_finder)) > 0:
if len(_get_tols(validated_solver.root_finder)) > 0:
raise ValueError(
"A fixed step size controller is being used alongside an implicit "
"solver, but the tolerances for the implicit solver have not been "
Expand Down Expand Up @@ -1248,24 +1269,24 @@ def _subsaveat_direction_fn(x):

# Initialise states
tprev = t0
error_order = solver.error_order(terms)
error_order = validated_solver.error_order(terms)
if controller_state is None:
passed_controller_state = False
(tnext, controller_state) = stepsize_controller.init(
terms, t0, t1, y0, dt0, args, solver.func, error_order
terms, t0, t1, y0, dt0, args, validated_solver.func, error_order
)
else:
passed_controller_state = True
if dt0 is None:
(tnext, _) = stepsize_controller.init(
terms, t0, t1, y0, dt0, args, solver.func, error_order
terms, t0, t1, y0, dt0, args, validated_solver.func, error_order
)
else:
tnext = t0 + dt0
tnext = jnp.minimum(tnext, t1)
if solver_state is None:
passed_solver_state = False
solver_state = solver.init(terms, t0, tnext, y0, args)
solver_state = validated_solver.init(terms, t0, tnext, y0, args)
else:
passed_solver_state = True

Expand Down Expand Up @@ -1310,7 +1331,14 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
result = RESULTS.successful
if saveat.dense or event is not None:
_, _, dense_info_struct, _, _ = eqx.filter_eval_shape(
solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump
validated_solver.step,
terms,
tprev,
tnext,
y0,
args,
solver_state,
made_jump,
)
if saveat.dense:
if max_steps is None:
Expand Down Expand Up @@ -1371,7 +1399,7 @@ def _outer_cond_fn(cond_fn_i):
y0,
args,
terms=terms,
solver=solver,
solver=validated_solver,
t0=t0,
t1=t1,
dt0=dt0,
Expand Down Expand Up @@ -1456,7 +1484,7 @@ def _outer_cond_fn(cond_fn_i):
final_state, aux_stats = adjoint.loop(
args=args,
terms=terms,
solver=solver,
solver=validated_solver,
stepsize_controller=stepsize_controller,
event=event,
saveat=saveat,
Expand Down Expand Up @@ -1503,7 +1531,7 @@ def _outer_cond_fn(cond_fn_i):
ts=final_state.dense_ts,
ts_size=final_state.dense_save_index + 1,
infos=final_state.dense_infos,
interpolation_cls=solver.interpolation_cls,
interpolation_cls=validated_solver.interpolation_cls,
direction=direction,
t0_if_trivial=t0,
y0_if_trivial=y0,
Expand Down
29 changes: 29 additions & 0 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,35 @@ def vector_field(t, y, args):
diffrax._integrate._PRINT_STATIC = False


def test_uninstantiated_solver_error():
msg = (
r"It looks like you forgot to instantiate your solver, e.g. by passing "
r"`diffrax\.Euler` instead of `diffrax\.Euler\(\)`."
)
term = ODETerm(lambda t, y, args: -y)
with pytest.raises(ValueError, match=msg):
diffrax.diffeqsolve(term, diffrax.Euler, 0, 1, 0.1, 1.0)
with pytest.raises(ValueError, match=msg):
diffrax.diffeqsolve(
MultiTerm(
ODETerm(lambda t, y, args: -y),
ControlTerm(
lambda t, y, args: 0.1 * t,
diffrax.VirtualBrownianTree(
0, 1, tol=1e-3, shape=(), key=jr.key(0)
),
),
),
diffrax.EulerHeun,
0,
1,
0.1,
1.0,
)
with pytest.raises(ValueError, match=r"not recognised"):
diffrax._integrate._validate_solver("not a solver")


def test_implicit_tol_error():
msg = "the tolerances for the implicit solver have not been specified"
with pytest.raises(ValueError, match=msg):
Expand Down