diff --git a/quantecon/_gridtools.py b/quantecon/_gridtools.py index 8915c345..f5a22713 100644 --- a/quantecon/_gridtools.py +++ b/quantecon/_gridtools.py @@ -11,7 +11,7 @@ def cartesian(nodes, order='C'): - ''' + """ Cartesian product of a list of arrays Parameters @@ -25,10 +25,12 @@ def cartesian(nodes, order='C'): ------- out : ndarray(ndim=2) each line corresponds to one point of the product space - ''' + """ + nodes = [np.asarray(e) for e in nodes] - shapes = [e.shape[0] for e in nodes] + shapes = tuple(e.shape[0] for e in nodes) + dtype = np.result_type(*nodes) @@ -36,23 +38,15 @@ def cartesian(nodes, order='C'): l = np.prod(shapes) out = np.zeros((l, n), dtype=dtype) - if order == 'C': - repetitions = np.cumprod([1] + shapes[:-1]) - else: - shapes.reverse() - sh = [1] + shapes[:-1] - repetitions = np.cumprod(sh) - repetitions = repetitions.tolist() - repetitions.reverse() - - for i in range(n): - _repeat_1d(nodes[i], repetitions[i], out[:, i]) + order_flag = 0 if order == 'C' else 1 + # pack nodes as tuple for numba. dtype is inferred by numpy ahead + _cartesian_numba(tuple(nodes), order_flag, shapes, dtype, out) return out def mlinspace(a, b, nums, order='C'): - ''' + """ Constructs a regular cartesian grid Parameters @@ -73,7 +67,8 @@ def mlinspace(a, b, nums, order='C'): ------- out : ndarray(ndim=2) each line corresponds to one point of the product space - ''' + """ + a = np.asarray(a, dtype='float64') b = np.asarray(b, dtype='float64') @@ -433,3 +428,43 @@ def num_compositions_jit(m, n): """ return comb_jit(n+m-1, m-1) + + +@njit(cache=True) +def _cartesian_numba(nodes, order_flag: int, shapes, dtype_num, out): + """ + High-performance implementation of cartesian() using Numba. + Parameters + ---------- + nodes : tuple of 1d ndarrays + order_flag : int (0 for 'C', 1 for 'F') + shapes : tuple of ints + dtype_num : np dtype typecode (not used with njit) + out : 2d ndarray (preallocated) + Returns + ------- + out : ndarray(ndim=2) + """ + n = len(nodes) + if order_flag == 0: + repetitions = np.empty(n, dtype=np.int64) + repetitions[0] = 1 + for i in range(1, n): + repetitions[i] = repetitions[i-1] * shapes[i-1] + else: # 'F' order + shapes_rev = np.empty(n, dtype=np.int64) + for i in range(n): + shapes_rev[i] = shapes[n-1-i] + repetitions = np.empty(n, dtype=np.int64) + repetitions[0] = 1 + for i in range(1, n): + repetitions[i] = repetitions[i-1] * shapes_rev[i-1] + repetitions_rev = np.empty(n, dtype=np.int64) + for i in range(n): + repetitions_rev[i] = repetitions[n-1 - i] + repetitions = repetitions_rev + + for i in range(n): + _repeat_1d(nodes[i], repetitions[i], out[:, i]) + + return out