diff --git a/quantecon/random/utilities.py b/quantecon/random/utilities.py index f1b53f5d..8ca4b0c7 100644 --- a/quantecon/random/utilities.py +++ b/quantecon/random/utilities.py @@ -4,7 +4,7 @@ """ import numpy as np -from numba import guvectorize, types +from numba import njit, guvectorize, types from numba.extending import overload from ..util import check_random_state, searchsorted @@ -200,13 +200,10 @@ def draw(cdf, size=None): """ if isinstance(size, int): rs = np.random.random(size) - out = np.empty(size, dtype=np.int_) - for i in range(size): - out[i] = searchsorted(cdf, rs[i]) - return out + return _draw_jit(np.asarray(cdf), rs) else: r = np.random.random() - return searchsorted(cdf, r) + return _searchsorted_jit(np.asarray(cdf), r) # Overload for the `draw` function @@ -224,3 +221,24 @@ def draw_impl(cdf, size=None): r = np.random.random() return searchsorted(cdf, r) return draw_impl + + +@njit(cache=True) +def _searchsorted_jit(cdf, v): + lo = 0 + hi = cdf.shape[0] + while lo < hi: + mid = (lo + hi) // 2 + if v < cdf[mid]: + hi = mid + else: + lo = mid + 1 + return lo + +@njit(cache=True) +def _draw_jit(cdf, rs): + size = rs.shape[0] + out = np.empty(size, dtype=np.int_) + for i in range(size): + out[i] = _searchsorted_jit(cdf, rs[i]) + return out