Skip to content
This repository was archived by the owner on Mar 25, 2025. It is now read-only.

Commit 00a0296

Browse files
Make "sparse" solver check if equations are linear.
If the system is linear, then newtons method always converges in exactly one iteration. When using the sparse solver on linear systems omit the newtons iteration and solve directly. This should make the resulting code run marginally faster by skipping the check for convergence. Currently the check for convergence is implemented as "error = sqrt(|F|^2)".
1 parent cde5dbf commit 00a0296

5 files changed

Lines changed: 50 additions & 42 deletions

File tree

nmodl/ode.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from importlib import import_module
99

1010
import sympy as sp
11+
import itertools
1112

1213
# import known_functions through low-level mechanism because the ccode
1314
# module is overwritten in sympy and contents of that submodule cannot be
@@ -272,6 +273,8 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
272273

273274
eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)
274275

276+
linear = _is_linear(eqs, state_vars, sympy_vars)
277+
275278
custom_fcts = _get_custom_functions(function_calls)
276279

277280
jacobian = sp.Matrix(eqs).jacobian(state_vars)
@@ -291,7 +294,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
291294
# interweave
292295
code = _interweave_eqs(vecFcode, vecJcode)
293296

294-
return code
297+
return code, linear
298+
299+
300+
def _is_linear(eqs, state_vars, sympy_vars):
301+
for expr in eqs:
302+
for (x, y) in itertools.combinations_with_replacement(state_vars, 2):
303+
try:
304+
if not sp.Eq(sp.diff(expr, x, y), 0):
305+
return False
306+
except TypeError:
307+
return False
308+
return True
295309

296310

297311
def integrate2c(diff_string, dt_var, vars, use_pade_approx=False):

src/pybind/pyembed.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ struct SolveNonLinearSystemExecutor: public PythonExecutor {
5656
// output
5757
// returns a vector of solutions, i.e. new statements to add to block:
5858
std::vector<std::string> solutions;
59+
// returns if the system is linear or not.
60+
bool linear;
5961
// may also return a python exception message:
6062
std::string exception_message;
6163

src/pybind/wrapper.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,22 @@ void SolveNonLinearSystemExecutor::operator()() {
6666
from nmodl.ode import solve_non_lin_system
6767
exception_message = ""
6868
try:
69-
solutions = solve_non_lin_system(equation_strings,
69+
solutions, linear = solve_non_lin_system(equation_strings,
7070
state_vars,
7171
vars,
7272
function_calls)
7373
except Exception as e:
7474
# if we fail, fail silently and return empty string
7575
solutions = [""]
76+
linear = False
7677
new_local_vars = [""]
7778
exception_message = str(e)
7879
)",
7980
py::globals(),
8081
locals);
8182
// returns a vector of solutions, i.e. new statements to add to block:
8283
solutions = locals["solutions"].cast<std::vector<std::string>>();
84+
linear = locals["linear"].cast<bool>();
8385
// may also return a python exception message:
8486
exception_message = locals["exception_message"].cast<std::string>();
8587
}

src/visitors/sympy_solver_visitor.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ void SympySolverVisitor::solve_non_linear_system(
356356
(*solver)();
357357
// returns a vector of solutions, i.e. new statements to add to block:
358358
auto solutions = solver->solutions;
359+
bool linear = solver->linear;
359360
// may also return a python exception message:
360361
auto exception_message = solver->exception_message;
361362
pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver);
@@ -364,8 +365,13 @@ void SympySolverVisitor::solve_non_linear_system(
364365
exception_message);
365366
return;
366367
}
367-
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
368-
construct_eigen_solver_block(pre_solve_statements, solutions, false);
368+
if (!linear) {
369+
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
370+
}
371+
else {
372+
logger->debug("SympySolverVisitor :: Constructing eigen solve block");
373+
}
374+
construct_eigen_solver_block(pre_solve_statements, solutions, linear);
369375
}
370376

371377
void SympySolverVisitor::visit_var_name(ast::VarName& node) {

0 commit comments

Comments
 (0)