diff --git a/sunode/test_solve.py b/sunode/test_solve.py index dce37d1..3b14355 100644 --- a/sunode/test_solve.py +++ b/sunode/test_solve.py @@ -152,3 +152,26 @@ def rhs(t, y, p): solver = AdjointSolver(problem) check_call_solve(solver, param_vals, "backward") + + +def test_linear_solver_kwarg(): + def rhs(t, y, p): + return { + 'x': y.x, + } + + states = { + 'x': (), + } + + params = { + 'b': () + } + param_vals = { + 'b': 0.2 + } + problem = SympyProblem(params, states, rhs, derivative_params=[]) + linear_solver_opts = ["dense", "dense_finitediff", "spgmr_finitediff", "spgmr"] + for linear_solver in linear_solver_opts: + solver = Solver(problem, linear_solver=linear_solver) + check_call_solve(solver, param_vals, None)