Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 51 additions & 16 deletions quantecon/_gridtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def cartesian(nodes, order='C'):
'''
"""
Cartesian product of a list of arrays

Parameters
Expand All @@ -25,34 +25,28 @@ def cartesian(nodes, order='C'):
-------
out : ndarray(ndim=2)
each line corresponds to one point of the product space
'''
"""


nodes = [np.asarray(e) for e in nodes]
shapes = [e.shape[0] for e in nodes]
shapes = tuple(e.shape[0] for e in nodes)


dtype = np.result_type(*nodes)

n = len(nodes)
l = np.prod(shapes)
out = np.zeros((l, n), dtype=dtype)

if order == 'C':
repetitions = np.cumprod([1] + shapes[:-1])
else:
shapes.reverse()
sh = [1] + shapes[:-1]
repetitions = np.cumprod(sh)
repetitions = repetitions.tolist()
repetitions.reverse()

for i in range(n):
_repeat_1d(nodes[i], repetitions[i], out[:, i])
order_flag = 0 if order == 'C' else 1
# pack nodes as tuple for numba. dtype is inferred by numpy ahead
_cartesian_numba(tuple(nodes), order_flag, shapes, dtype, out)

return out


def mlinspace(a, b, nums, order='C'):
'''
"""
Constructs a regular cartesian grid

Parameters
Expand All @@ -73,7 +67,8 @@ def mlinspace(a, b, nums, order='C'):
-------
out : ndarray(ndim=2)
each line corresponds to one point of the product space
'''
"""


a = np.asarray(a, dtype='float64')
b = np.asarray(b, dtype='float64')
Expand Down Expand Up @@ -433,3 +428,43 @@ def num_compositions_jit(m, n):

"""
return comb_jit(n+m-1, m-1)


@njit(cache=True)
def _cartesian_numba(nodes, order_flag: int, shapes, dtype_num, out):
"""
High-performance implementation of cartesian() using Numba.
Parameters
----------
nodes : tuple of 1d ndarrays
order_flag : int (0 for 'C', 1 for 'F')
shapes : tuple of ints
dtype_num : np dtype typecode (not used with njit)
out : 2d ndarray (preallocated)
Returns
-------
out : ndarray(ndim=2)
"""
n = len(nodes)
if order_flag == 0:
repetitions = np.empty(n, dtype=np.int64)
repetitions[0] = 1
for i in range(1, n):
repetitions[i] = repetitions[i-1] * shapes[i-1]
else: # 'F' order
shapes_rev = np.empty(n, dtype=np.int64)
for i in range(n):
shapes_rev[i] = shapes[n-1-i]
repetitions = np.empty(n, dtype=np.int64)
repetitions[0] = 1
for i in range(1, n):
repetitions[i] = repetitions[i-1] * shapes_rev[i-1]
repetitions_rev = np.empty(n, dtype=np.int64)
for i in range(n):
repetitions_rev[i] = repetitions[n-1 - i]
repetitions = repetitions_rev

for i in range(n):
_repeat_1d(nodes[i], repetitions[i], out[:, i])

return out