Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,19 @@ The following flags exist:
- Default
- Description
* - ``disable_analytic_solver``
- False
- :python:`False`
- Set to True to return numerical solver recommendations, and no propagators, even for ODEs that are analytically tractable.
* - ``disable_stiffness_check``
- False
- :python:`False`
- Set to True to disable stiffness check.
* - ``disable_singularity_detection``
- False
- :python:`False`
- Set to True to disable detection of conditions under which numerical singularities (division by zero) could occur.
* - ``use_alternative_expM``
- :python:`False`
- If :python:`False`, use the sympy function ``sympy.exp`` to compute the matrix exponential. If :python:`True`, use an alternative function (see :py:func:`odetoolbox.sympy_helpers.expMt` for details). This can be useful as calls to ``sympy.exp`` can sometimes take a very large amount of time.
* - ``preserve_expressions``
- False
- :python:`False`
- Set to True, or a list of strings corresponding to individual variable names, to disable internal rewriting of expressions, and return same output as input expression where possible. Only applies to variables specified as first-order differential equations.
* - ``log_level``
- :python:`logging.WARN`
Expand Down
9 changes: 6 additions & 3 deletions odetoolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _get_all_first_order_variables(indict) -> Iterable[str]:
return variable_names


def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]:
def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, use_alternative_expM: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]:
r"""
Like analysis(), but additionally returns ``shape_sys`` and ``shapes``.

Expand Down Expand Up @@ -251,7 +251,7 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
if analytic_syms:
logging.info("Generating propagators for the following symbols: " + ", ".join([str(k) for k in analytic_syms]))
sub_sys = shape_sys.get_sub_system(analytic_syms)
analytic_solver_json = sub_sys.generate_propagator_solver(disable_singularity_detection=disable_singularity_detection)
analytic_solver_json = sub_sys.generate_propagator_solver(disable_singularity_detection=disable_singularity_detection, use_alternative_expM=use_alternative_expM)
analytic_solver_json["solver"] = "analytical"
solvers_json.append(analytic_solver_json)

Expand Down Expand Up @@ -405,13 +405,15 @@ def _init_logging(log_level: Union[str, int] = logging.WARNING):
logging.getLogger().setLevel(log_level)


def analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> List[Dict]:
def analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, use_alternative_expM: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> List[Dict]:
r"""
The main entry point of the ODE-toolbox API.

:param indict: Input dictionary for the analysis. For details, see https://ode-toolbox.readthedocs.io/en/master/#input
:param disable_stiffness_check: Whether to perform stiffness checking.
:param disable_analytic_solver: Set to True to return numerical solver recommendations, and no propagators, even for ODEs that are analytically tractable.
:param disable_singularity_detection: Set to True to disable detection of conditions under which numerical singularities (division by zero) could occur in the generated analytic solver. This can be useful for analytic solvers containing a large amount of conditions, which could take a long time to compute. If True, at most one analytic solver will be returned, in which numerical singularities could occur.
:param use_alternative_expM: If :python:`False`, use the sympy function ``sympy.exp`` to compute the matrix exponential. If :python:`True`, use an alternative function (see :py:func:`odetoolbox.sympy_helpers.expMt` for details). This can be useful as calls to ``sympy.exp`` can sometimes take a very large amount of time.
:param preserve_expressions: Set to True, or a list of strings corresponding to individual variable names, to disable internal rewriting of expressions, and return same output as input expression where possible. Only applies to variables specified as first-order differential equations.
:param log_level: Sets the logging threshold. Logging messages which are less severe than ``log_level`` will be ignored. Log levels can be provided as an integer or string, for example "INFO" (more messages) or "WARN" (fewer messages). For a list of valid logging levels, see https://docs.python.org/3/library/logging.html#logging-levels

Expand All @@ -421,6 +423,7 @@ def analysis(indict, disable_stiffness_check: bool = False, disable_analytic_sol
disable_stiffness_check=disable_stiffness_check,
disable_analytic_solver=disable_analytic_solver,
disable_singularity_detection=disable_singularity_detection,
use_alternative_expM=use_alternative_expM,
preserve_expressions=preserve_expressions,
log_level=log_level)
return d
68 changes: 68 additions & 0 deletions odetoolbox/sympy_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,71 @@ def _print_Function(self, expr):
return expr.func.__name__.lower() + "(%s)" % self.stringify(expr.args, ", ")

return expr.func.__name__ + "(%s)" % self.stringify(expr.args, ", ")


def expMt(M, t=1):
"""Compute matrix exponential exp(M*t).

Based on code contributed by GitHub user @oscarbenjamin, July 29th, 2021 [1]_ (see also the discussion at [2]_).

.. [1] https://github.com/sympy/sympy/issues/21585
.. [2] https://github.com/nest/ode-toolbox/pull/97
"""
def ilt(e, s, t):
"""Fast inverse Laplace transform of rational function including RootSum"""
a, b, n = sympy.symbols('a, b, n', cls=sympy.Wild, exclude=[s])

def _ilt(e):
if not e.has(s):
return e
elif e.is_Add:
return _ilt_add(e)
elif e.is_Mul:
return _ilt_mul(e)
elif e.is_Pow:
return _ilt_pow(e)
elif isinstance(e, sympy.RootSum):
return _ilt_rootsum(e)
else:
raise NotImplementedError

def _ilt_add(e):
return e.func(*map(_ilt, e.args))

def _ilt_mul(e):
coeff, expr = e.as_independent(s)
if expr.is_Mul:
raise NotImplementedError
return coeff * _ilt(expr)

def _ilt_pow(e):
match = e.match((a * s + b)**n)
if match is not None:
nm, am, bm = match[n], match[a], match[b]
if nm.is_Integer and nm < 0:
if nm == 1:
return sympy.exp(-(bm / am) * t) / am
else:
return t**(-nm - 1) * sympy.exp(-(bm / am) * t) / (am**-nm * sympy.gamma(-nm))
raise NotImplementedError

def _ilt_rootsum(e):
expr = e.fun.expr
[variable] = e.fun.variables
return sympy.RootSum(e.poly, sympy.Lambda(variable, sympy.together(_ilt(expr))))

return _ilt(e)

assert M.is_square
N = M.shape[0]
s = sympy.Dummy("s")

Ms = (s * sympy.eye(N) - M)
Mres = Ms.adjugate() / Ms.det()

def expMij(i, j):
"""Partial fraction expansion then inverse Laplace transform"""
Mresij_pfe = sympy.apart(Mres[i, j], s, full=True)
return ilt(Mresij_pfe, s, t)

return sympy.Matrix(N, N, expMij)
16 changes: 11 additions & 5 deletions odetoolbox/system_of_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .config import Config
from .shapes import Shape
from .singularity_detection import SingularityDetection, SingularityDetectionException
from .sympy_helpers import _custom_simplify_expr, _is_zero
from .sympy_helpers import _custom_simplify_expr, _is_zero, expMt


class GetBlockDiagonalException(Exception):
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_sub_system(self, symbols):
return SystemOfShapes(x_sub, A_sub, b_sub, c_sub, shapes_sub)


def _generate_propagator_matrix(self, A):
def _generate_propagator_matrix(self, A, use_alternative_expM: bool = False):
r"""Generate the propagator matrix by matrix exponentiation."""

# naive: calculate propagators in one step
Expand All @@ -214,7 +214,13 @@ def _generate_propagator_matrix(self, A):
# optimized: be explicit about block diagonal elements; much faster!
try:
blocks = get_block_diagonal_blocks(np.array(A))
propagators = [sympy.simplify(sympy.exp(sympy.Matrix(block) * sympy.Symbol(Config().output_timestep_symbol))) for block in blocks]

if use_alternative_expM:
expM = expMt
else:
expM = sympy.exp

propagators = [sympy.simplify(expM(sympy.Matrix(block) * sympy.Symbol(Config().output_timestep_symbol, real=True))) for block in blocks]
P = sympy.Matrix(scipy.linalg.block_diag(*propagators))
except GetBlockDiagonalException:
# naive: calculate propagators in one step
Expand All @@ -226,12 +232,12 @@ def _generate_propagator_matrix(self, A):

return P

def generate_propagator_solver(self, disable_singularity_detection: bool = False):
def generate_propagator_solver(self, disable_singularity_detection: bool = False, use_alternative_expM: bool = False):
r"""
Generate the propagator matrix and symbolic expressions for propagator-based updates; return as JSON.
"""

P = self._generate_propagator_matrix(self.A_)
P = self._generate_propagator_matrix(self.A_, use_alternative_expM=use_alternative_expM)

#
# singularity detection
Expand Down
6 changes: 4 additions & 2 deletions tests/test_analytic_solver_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import math
import numpy as np
import os
import pytest
import sympy
import sympy.parsing.sympy_parser
import scipy
Expand Down Expand Up @@ -97,7 +98,8 @@ class TestAnalyticSolverIntegration:
\mathbf{z}(t + h) = \mathbf{P} \cdot \mathbf{z}(t)
"""

def test_analytic_solver_integration_psc_alpha(self):
@pytest.mark.parametrize("use_alternative_expM", [True, False])
def test_analytic_solver_integration_psc_alpha(self, use_alternative_expM: bool):
h = 1E-3 # [s]
T = 20E-3 # [s]

Expand Down Expand Up @@ -190,7 +192,7 @@ def f(t, y):

print("Starting ODE-toolbox analysis...")
indict = _open_json("test_integration.json")
solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True, log_level="DEBUG")
solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True, use_alternative_expM=use_alternative_expM, log_level="DEBUG")
assert len(solver_dict) == 1
solver_dict = solver_dict[0]
assert solver_dict["solver"] == "analytical"
Expand Down
11 changes: 7 additions & 4 deletions tests/test_double_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#

import numpy as np
import pytest
from scipy.integrate import odeint

import odetoolbox
Expand All @@ -39,7 +40,8 @@
class TestDoubleExponential:
r"""Test propagators generation for double exponential"""

def test_double_exponential(self):
@pytest.mark.parametrize("use_alternative_expM", [True, False])
def test_double_exponential(self, use_alternative_expM: bool):
r"""Test propagators generation for double exponential"""

def time_to_max(tau_1, tau_2):
Expand Down Expand Up @@ -95,7 +97,7 @@ def flow(y, t, tau_1, tau_2, alpha, dt):
ODE_INITIAL_VALUES = {"I": 0., "I_aux": 0.}

# simulate with ode-toolbox
solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True)
solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True, use_alternative_expM=use_alternative_expM)
assert len(solver_dict) == 1
solver_dict = solver_dict[0]
assert solver_dict["solver"] == "analytical"
Expand Down Expand Up @@ -152,12 +154,13 @@ def flow(y, t, tau_1, tau_2, alpha, dt):

np.testing.assert_allclose(y_[:, 1], rec_I_interp, atol=1E-7)

def test_constant_factors_double_exponential(self):
@pytest.mark.parametrize("use_alternative_expM", [True, False])
def test_constant_factors_double_exponential(self, use_alternative_expM: bool):
r"""Test the computation of propagators for an alpha (double-exponential) kernel with constant coefficients; this tests the block-diagonal computation of propagators."""
indict = {"dynamics": [{"expression": "x'' = -2 * x' - x",
"initial_values": {"x": "0",
"x'": "0"}}]}
solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True)
solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True, use_alternative_expM=use_alternative_expM)
assert len(solver_dict) == 1
solver_dict = solver_dict[0]
assert solver_dict["solver"] == "analytical"
Expand Down
Loading