From 7214553a6ea6fe99f8f1371c2c9aee158d0b9b83 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Wed, 26 Nov 2025 00:05:50 +0100 Subject: [PATCH 1/4] add alternative matrix exponential function --- doc/index.rst | 4 ++ odetoolbox/config.py | 3 +- odetoolbox/sympy_helpers.py | 67 ++++++++++++++++++++++++++++++++++ odetoolbox/system_of_shapes.py | 10 ++++- 4 files changed, 81 insertions(+), 3 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index 05380335..6e4a0a48 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -360,6 +360,10 @@ The following global options are defined. Note that all are typically formatted - :python:`["oo", "zoo", "nan", "NaN", "__h"]` - list of strings - For each forbidden name: emit an error if a variable or parameter by this name occurs in the input. + * - ``use_alternative_expM`` + - :python:`False` + - boolean + - If :python:`False`, use the sympy function ``sympy.exp`` to compute the matrix exponential. If :python:`True`, use an alternative function (see :py:func:`odetoolbox.sympy_helpers.expMt` for details). This can be useful as calls to ``sympy.exp`` can sometimes take a very large amount of time. Output diff --git a/odetoolbox/config.py b/odetoolbox/config.py index 09444a6c..43f3f6a5 100644 --- a/odetoolbox/config.py +++ b/odetoolbox/config.py @@ -36,7 +36,8 @@ class Config: "max_step_size": 999., "integration_accuracy_abs": 1E-6, "integration_accuracy_rel": 1E-6, - "forbidden_names": ["oo", "zoo", "nan", "NaN", "__h"] + "forbidden_names": ["oo", "zoo", "nan", "NaN", "__h"], + "use_alternative_expM": False } def __getitem__(self, key): diff --git a/odetoolbox/sympy_helpers.py b/odetoolbox/sympy_helpers.py index 9c1ac74c..6696f44d 100644 --- a/odetoolbox/sympy_helpers.py +++ b/odetoolbox/sympy_helpers.py @@ -166,3 +166,70 @@ def _print_Function(self, expr): return expr.func.__name__.lower() + "(%s)" % self.stringify(expr.args, ", ") return expr.func.__name__ + "(%s)" % self.stringify(expr.args, ", ") + + +def expMt(M, t=1): + """Compute matrix exponential exp(M*t). + + Based on code contributed by GitHub user @oscarbenjamin, July 29th, 2021 [1]_. + + .. [1] https://github.com/sympy/sympy/issues/21585 + """ + def ilt(e, s, t): + """Fast inverse Laplace transform of rational function including RootSum""" + a, b, n = sympy.symbols('a, b, n', cls=sympy.Wild, exclude=[s]) + + def _ilt(e): + if not e.has(s): + return e + elif e.is_Add: + return _ilt_add(e) + elif e.is_Mul: + return _ilt_mul(e) + elif e.is_Pow: + return _ilt_pow(e) + elif isinstance(e, sympy.RootSum): + return _ilt_rootsum(e) + else: + raise NotImplementedError + + def _ilt_add(e): + return e.func(*map(_ilt, e.args)) + + def _ilt_mul(e): + coeff, expr = e.as_independent(s) + if expr.is_Mul: + raise NotImplementedError + return coeff * _ilt(expr) + + def _ilt_pow(e): + match = e.match((a * s + b)**n) + if match is not None: + nm, am, bm = match[n], match[a], match[b] + if nm.is_Integer and nm < 0: + if nm == 1: + return sympy.exp(-(bm / am) * t) / am + else: + return t**(-nm - 1) * sympy.exp(-(bm / am) * t) / (am**-nm * sympy.gamma(-nm)) + raise NotImplementedError + + def _ilt_rootsum(e): + expr = e.fun.expr + [variable] = e.fun.variables + return sympy.RootSum(e.poly, sympy.Lambda(variable, sympy.together(_ilt(expr)))) + + return _ilt(e) + + assert M.is_square + N = M.shape[0] + s = sympy.Dummy("s") + + Ms = (s * sympy.eye(N) - M) + Mres = Ms.adjugate() / Ms.det() + + def expMij(i, j): + """Partial fraction expansion then inverse Laplace transform""" + Mresij_pfe = sympy.apart(Mres[i, j], s, full=True) + return ilt(Mresij_pfe, s, t) + + return sympy.Matrix(N, N, expMij) diff --git a/odetoolbox/system_of_shapes.py b/odetoolbox/system_of_shapes.py index 0ae0c07d..357efc4d 100644 --- a/odetoolbox/system_of_shapes.py +++ b/odetoolbox/system_of_shapes.py @@ -32,7 +32,7 @@ from .config import Config from .shapes import Shape from .singularity_detection import SingularityDetection, SingularityDetectionException -from .sympy_helpers import _custom_simplify_expr, _is_zero +from .sympy_helpers import _custom_simplify_expr, _is_zero, expMt class GetBlockDiagonalException(Exception): @@ -214,7 +214,13 @@ def _generate_propagator_matrix(self, A): # optimized: be explicit about block diagonal elements; much faster! try: blocks = get_block_diagonal_blocks(np.array(A)) - propagators = [sympy.simplify(sympy.exp(sympy.Matrix(block) * sympy.Symbol(Config().output_timestep_symbol))) for block in blocks] + + if Config().use_alternative_expM: + expM = expMt + else: + expM = sympy.exp + + propagators = [sympy.simplify(expM(sympy.Matrix(block) * sympy.Symbol(Config().output_timestep_symbol, real=True))) for block in blocks] P = sympy.Matrix(scipy.linalg.block_diag(*propagators)) except GetBlockDiagonalException: # naive: calculate propagators in one step From 7909ea5660a1481bc9e9bc58aa3dce57ca65cde1 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Sun, 28 Dec 2025 13:26:45 +0100 Subject: [PATCH 2/4] add alternative matrix exponential function --- odetoolbox/sympy_helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/odetoolbox/sympy_helpers.py b/odetoolbox/sympy_helpers.py index 6696f44d..aecf8b8c 100644 --- a/odetoolbox/sympy_helpers.py +++ b/odetoolbox/sympy_helpers.py @@ -171,9 +171,10 @@ def _print_Function(self, expr): def expMt(M, t=1): """Compute matrix exponential exp(M*t). - Based on code contributed by GitHub user @oscarbenjamin, July 29th, 2021 [1]_. + Based on code contributed by GitHub user @oscarbenjamin, July 29th, 2021 [1]_ (see also the discussion at [2]_). .. [1] https://github.com/sympy/sympy/issues/21585 + .. [2] https://github.com/nest/ode-toolbox/pull/97 """ def ilt(e, s, t): """Fast inverse Laplace transform of rational function including RootSum""" From 23c679ca72ce8aef6985fed4bfd77c3bb0732941 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Sun, 11 Jan 2026 21:08:56 +0100 Subject: [PATCH 3/4] move use_alternative_expM from global config to analysis call --- doc/index.rst | 15 +++++++-------- odetoolbox/__init__.py | 9 ++++++--- odetoolbox/config.py | 3 +-- odetoolbox/system_of_shapes.py | 8 ++++---- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index 6e4a0a48..87bbf643 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -127,16 +127,19 @@ The following flags exist: - Default - Description * - ``disable_analytic_solver`` - - False + - :python:`False` - Set to True to return numerical solver recommendations, and no propagators, even for ODEs that are analytically tractable. * - ``disable_stiffness_check`` - - False + - :python:`False` - Set to True to disable stiffness check. * - ``disable_singularity_detection`` - - False + - :python:`False` - Set to True to disable detection of conditions under which numerical singularities (division by zero) could occur. + * - ``use_alternative_expM`` + - :python:`False` + - If :python:`False`, use the sympy function ``sympy.exp`` to compute the matrix exponential. If :python:`True`, use an alternative function (see :py:func:`odetoolbox.sympy_helpers.expMt` for details). This can be useful as calls to ``sympy.exp`` can sometimes take a very large amount of time. * - ``preserve_expressions`` - - False + - :python:`False` - Set to True, or a list of strings corresponding to individual variable names, to disable internal rewriting of expressions, and return same output as input expression where possible. Only applies to variables specified as first-order differential equations. * - ``log_level`` - :python:`logging.WARN` @@ -360,10 +363,6 @@ The following global options are defined. Note that all are typically formatted - :python:`["oo", "zoo", "nan", "NaN", "__h"]` - list of strings - For each forbidden name: emit an error if a variable or parameter by this name occurs in the input. - * - ``use_alternative_expM`` - - :python:`False` - - boolean - - If :python:`False`, use the sympy function ``sympy.exp`` to compute the matrix exponential. If :python:`True`, use an alternative function (see :py:func:`odetoolbox.sympy_helpers.expMt` for details). This can be useful as calls to ``sympy.exp`` can sometimes take a very large amount of time. Output diff --git a/odetoolbox/__init__.py b/odetoolbox/__init__.py index 89c42cc6..340c9ce8 100644 --- a/odetoolbox/__init__.py +++ b/odetoolbox/__init__.py @@ -182,7 +182,7 @@ def _get_all_first_order_variables(indict) -> Iterable[str]: return variable_names -def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]: +def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, use_alternative_expM: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]: r""" Like analysis(), but additionally returns ``shape_sys`` and ``shapes``. @@ -251,7 +251,7 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so if analytic_syms: logging.info("Generating propagators for the following symbols: " + ", ".join([str(k) for k in analytic_syms])) sub_sys = shape_sys.get_sub_system(analytic_syms) - analytic_solver_json = sub_sys.generate_propagator_solver(disable_singularity_detection=disable_singularity_detection) + analytic_solver_json = sub_sys.generate_propagator_solver(disable_singularity_detection=disable_singularity_detection, use_alternative_expM=use_alternative_expM) analytic_solver_json["solver"] = "analytical" solvers_json.append(analytic_solver_json) @@ -405,13 +405,15 @@ def _init_logging(log_level: Union[str, int] = logging.WARNING): logging.getLogger().setLevel(log_level) -def analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> List[Dict]: +def analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, use_alternative_expM: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> List[Dict]: r""" The main entry point of the ODE-toolbox API. :param indict: Input dictionary for the analysis. For details, see https://ode-toolbox.readthedocs.io/en/master/#input :param disable_stiffness_check: Whether to perform stiffness checking. :param disable_analytic_solver: Set to True to return numerical solver recommendations, and no propagators, even for ODEs that are analytically tractable. + :param disable_singularity_detection: Set to True to disable detection of conditions under which numerical singularities (division by zero) could occur in the generated analytic solver. This can be useful for analytic solvers containing a large amount of conditions, which could take a long time to compute. If True, at most one analytic solver will be returned, in which numerical singularities could occur. + :param use_alternative_expM: If :python:`False`, use the sympy function ``sympy.exp`` to compute the matrix exponential. If :python:`True`, use an alternative function (see :py:func:`odetoolbox.sympy_helpers.expMt` for details). This can be useful as calls to ``sympy.exp`` can sometimes take a very large amount of time. :param preserve_expressions: Set to True, or a list of strings corresponding to individual variable names, to disable internal rewriting of expressions, and return same output as input expression where possible. Only applies to variables specified as first-order differential equations. :param log_level: Sets the logging threshold. Logging messages which are less severe than ``log_level`` will be ignored. Log levels can be provided as an integer or string, for example "INFO" (more messages) or "WARN" (fewer messages). For a list of valid logging levels, see https://docs.python.org/3/library/logging.html#logging-levels @@ -421,6 +423,7 @@ def analysis(indict, disable_stiffness_check: bool = False, disable_analytic_sol disable_stiffness_check=disable_stiffness_check, disable_analytic_solver=disable_analytic_solver, disable_singularity_detection=disable_singularity_detection, + use_alternative_expM=use_alternative_expM, preserve_expressions=preserve_expressions, log_level=log_level) return d diff --git a/odetoolbox/config.py b/odetoolbox/config.py index 43f3f6a5..09444a6c 100644 --- a/odetoolbox/config.py +++ b/odetoolbox/config.py @@ -36,8 +36,7 @@ class Config: "max_step_size": 999., "integration_accuracy_abs": 1E-6, "integration_accuracy_rel": 1E-6, - "forbidden_names": ["oo", "zoo", "nan", "NaN", "__h"], - "use_alternative_expM": False + "forbidden_names": ["oo", "zoo", "nan", "NaN", "__h"] } def __getitem__(self, key): diff --git a/odetoolbox/system_of_shapes.py b/odetoolbox/system_of_shapes.py index 357efc4d..e83355a8 100644 --- a/odetoolbox/system_of_shapes.py +++ b/odetoolbox/system_of_shapes.py @@ -205,7 +205,7 @@ def get_sub_system(self, symbols): return SystemOfShapes(x_sub, A_sub, b_sub, c_sub, shapes_sub) - def _generate_propagator_matrix(self, A): + def _generate_propagator_matrix(self, A, use_alternative_expM: bool = False): r"""Generate the propagator matrix by matrix exponentiation.""" # naive: calculate propagators in one step @@ -215,7 +215,7 @@ def _generate_propagator_matrix(self, A): try: blocks = get_block_diagonal_blocks(np.array(A)) - if Config().use_alternative_expM: + if use_alternative_expM: expM = expMt else: expM = sympy.exp @@ -232,12 +232,12 @@ def _generate_propagator_matrix(self, A): return P - def generate_propagator_solver(self, disable_singularity_detection: bool = False): + def generate_propagator_solver(self, disable_singularity_detection: bool = False, use_alternative_expM: bool = False): r""" Generate the propagator matrix and symbolic expressions for propagator-based updates; return as JSON. """ - P = self._generate_propagator_matrix(self.A_) + P = self._generate_propagator_matrix(self.A_, use_alternative_expM=use_alternative_expM) # # singularity detection From 22e7628e4f6c3cc3fa73687a7cdc4e4e13e09a41 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Wed, 28 Jan 2026 14:43:58 +0100 Subject: [PATCH 4/4] add tests for alternative expM function --- tests/test_analytic_solver_integration.py | 6 ++++-- tests/test_double_exponential.py | 11 +++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/test_analytic_solver_integration.py b/tests/test_analytic_solver_integration.py index 780698b6..174ecd89 100644 --- a/tests/test_analytic_solver_integration.py +++ b/tests/test_analytic_solver_integration.py @@ -22,6 +22,7 @@ import math import numpy as np import os +import pytest import sympy import sympy.parsing.sympy_parser import scipy @@ -97,7 +98,8 @@ class TestAnalyticSolverIntegration: \mathbf{z}(t + h) = \mathbf{P} \cdot \mathbf{z}(t) """ - def test_analytic_solver_integration_psc_alpha(self): + @pytest.mark.parametrize("use_alternative_expM", [True, False]) + def test_analytic_solver_integration_psc_alpha(self, use_alternative_expM: bool): h = 1E-3 # [s] T = 20E-3 # [s] @@ -190,7 +192,7 @@ def f(t, y): print("Starting ODE-toolbox analysis...") indict = _open_json("test_integration.json") - solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True, log_level="DEBUG") + solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True, use_alternative_expM=use_alternative_expM, log_level="DEBUG") assert len(solver_dict) == 1 solver_dict = solver_dict[0] assert solver_dict["solver"] == "analytical" diff --git a/tests/test_double_exponential.py b/tests/test_double_exponential.py index 025d5839..7e0c3b3e 100644 --- a/tests/test_double_exponential.py +++ b/tests/test_double_exponential.py @@ -20,6 +20,7 @@ # import numpy as np +import pytest from scipy.integrate import odeint import odetoolbox @@ -39,7 +40,8 @@ class TestDoubleExponential: r"""Test propagators generation for double exponential""" - def test_double_exponential(self): + @pytest.mark.parametrize("use_alternative_expM", [True, False]) + def test_double_exponential(self, use_alternative_expM: bool): r"""Test propagators generation for double exponential""" def time_to_max(tau_1, tau_2): @@ -95,7 +97,7 @@ def flow(y, t, tau_1, tau_2, alpha, dt): ODE_INITIAL_VALUES = {"I": 0., "I_aux": 0.} # simulate with ode-toolbox - solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True) + solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True, use_alternative_expM=use_alternative_expM) assert len(solver_dict) == 1 solver_dict = solver_dict[0] assert solver_dict["solver"] == "analytical" @@ -152,12 +154,13 @@ def flow(y, t, tau_1, tau_2, alpha, dt): np.testing.assert_allclose(y_[:, 1], rec_I_interp, atol=1E-7) - def test_constant_factors_double_exponential(self): + @pytest.mark.parametrize("use_alternative_expM", [True, False]) + def test_constant_factors_double_exponential(self, use_alternative_expM: bool): r"""Test the computation of propagators for an alpha (double-exponential) kernel with constant coefficients; this tests the block-diagonal computation of propagators.""" indict = {"dynamics": [{"expression": "x'' = -2 * x' - x", "initial_values": {"x": "0", "x'": "0"}}]} - solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True) + solver_dict = odetoolbox.analysis(indict, log_level="DEBUG", disable_stiffness_check=True, use_alternative_expM=use_alternative_expM) assert len(solver_dict) == 1 solver_dict = solver_dict[0] assert solver_dict["solver"] == "analytical"