Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions quantecon/_ivp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""
import numpy as np
from numba import njit
from scipy import integrate, interpolate


Expand Down Expand Up @@ -48,13 +49,18 @@ def __init__(self, f, jac=None):
def _integrate_fixed_trajectory(self, h, T, step, relax):
"""Generates a solution trajectory of fixed length."""
# initialize the solution using initial condition
solution = np.hstack((self.t, self.y))
solution = _hstack_numba(self.t, self.y)
solution = solution.reshape(1, -1)

# Preallocate for performance: we collect steps for stacking in Python
steps = []

while self.successful():

self.integrate(self.t + h, step, relax)
current_step = np.hstack((self.t, self.y))
solution = np.vstack((solution, current_step))
current_step = _hstack_numba(self.t, self.y)
steps.append(current_step)


if (h > 0) and (self.t >= T):
break
Expand All @@ -63,24 +69,37 @@ def _integrate_fixed_trajectory(self, h, T, step, relax):
else:
continue


if steps:
# vstack in one go at the end for performance
steps_array = np.vstack(steps)
solution = np.vstack((solution, steps_array))
return solution

def _integrate_variable_trajectory(self, h, g, tol, step, relax):
"""Generates a solution trajectory of variable length."""
# initialize the solution using initial condition
solution = np.hstack((self.t, self.y))
solution = _hstack_numba(self.t, self.y)
solution = solution.reshape(1, -1)

steps = []

while self.successful():

self.integrate(self.t + h, step, relax)
current_step = np.hstack((self.t, self.y))
solution = np.vstack((solution, current_step))
current_step = _hstack_numba(self.t, self.y)
steps.append(current_step)


if g(self.t, self.y, *self.f_params) < tol:
break
else:
continue


if steps:
steps_array = np.vstack(steps)
solution = np.vstack((solution, steps_array))
return solution

def _initialize_integrator(self, t0, y0, integrator, **kwargs):
Expand Down Expand Up @@ -138,7 +157,7 @@ def compute_residual(self, traj, ti, k=3, ext=2):

def solve(self, t0, y0, h=1.0, T=None, g=None, tol=None,
integrator='dopri5', step=False, relax=False, **kwargs):
r"""
"""
Solve the IVP by integrating the ODE given some initial condition.

Parameters
Expand Down Expand Up @@ -236,3 +255,26 @@ def interpolate(self, traj, ti, k=3, der=0, ext=2):
interp_traj = np.hstack((ti[:, np.newaxis], np.array(out).T))

return interp_traj


@njit(cache=True)
def _hstack_numba(t: float, y: np.ndarray) -> np.ndarray:
"""
Optimized version of np.hstack((t, y)) for float t and 1D np.ndarray y.
"""
result = np.empty(y.shape[0] + 1, dtype=y.dtype)
result[0] = t
result[1:] = y
return result

@njit(cache=True)
def _vstack_numba(solution: np.ndarray, current_step: np.ndarray) -> np.ndarray:
"""
Optimized version of np.vstack((solution, current_step))
"""
s_shape = solution.shape
c_shape = current_step.shape
out = np.empty((s_shape[0] + 1, s_shape[1]), dtype=solution.dtype)
out[:-1] = solution
out[-1] = current_step
return out