diff --git a/quantecon/_ivp.py b/quantecon/_ivp.py index 39cf08bd..d6058006 100644 --- a/quantecon/_ivp.py +++ b/quantecon/_ivp.py @@ -15,6 +15,7 @@ """ import numpy as np +from numba import njit from scipy import integrate, interpolate @@ -48,13 +49,18 @@ def __init__(self, f, jac=None): def _integrate_fixed_trajectory(self, h, T, step, relax): """Generates a solution trajectory of fixed length.""" # initialize the solution using initial condition - solution = np.hstack((self.t, self.y)) + t_initial = self.t + y_initial = self.y.copy() # .y is a numpy array + solution = _init_solution(t_initial, y_initial) + while self.successful(): self.integrate(self.t + h, step, relax) - current_step = np.hstack((self.t, self.y)) - solution = np.vstack((solution, current_step)) + t_current = self.t + y_current = self.y.copy() + solution = _append_solution(solution, t_current, y_current) + if (h > 0) and (self.t >= T): break @@ -236,3 +242,19 @@ def interpolate(self, traj, ti, k=3, der=0, ext=2): interp_traj = np.hstack((ti[:, np.newaxis], np.array(out).T)) return interp_traj + +@njit(cache=True) +def _init_solution(t: float, y: np.ndarray) -> np.ndarray: + # np.hstack((self.t, self.y)) but numba only supports appending via np.concatenate + sol = np.empty(y.size + 1, dtype=np.float64) + sol[0] = t + sol[1:] = y + return sol.reshape(1, y.size + 1) + +@njit(cache=True) +def _append_solution(solution: np.ndarray, t: float, y: np.ndarray) -> np.ndarray: + step = np.empty(y.size + 1, dtype=np.float64) + step[0] = t + step[1:] = y + step = step.reshape(1, y.size + 1) + return np.vstack((solution, step))