diff --git a/quantecon/_compute_fp.py b/quantecon/_compute_fp.py index 5b03041f..b80b4dca 100644 --- a/quantecon/_compute_fp.py +++ b/quantecon/_compute_fp.py @@ -6,7 +6,7 @@ import time import warnings import numpy as np -from numba import jit, types +from numba import njit, jit, types from numba.extending import overload from .game_theory.lemke_howson import _lemke_howson_tbl, _get_mixed_actions @@ -40,7 +40,29 @@ def _print_after_skip(skip, it=None, dist=None, etime=None): def _is_approx_fp(T, v, error_tol, *args, **kwargs): - error = np.max(np.abs(T(v, *args, **kwargs) - v)) + result = T(v, *args, **kwargs) + + # Check if we can use the optimized numba function + # Both inputs must be numpy arrays with same shape, dtype, and numeric dtype + can_use_numba = ( + isinstance(result, np.ndarray) and + isinstance(v, np.ndarray) and + result.shape == v.shape and + result.dtype == v.dtype and + np.issubdtype(result.dtype, np.number) and + result.size > 0 # Avoid empty arrays + ) + + if can_use_numba: + try: + error = _numba_max_abs_diff(result.ravel(), v.ravel()) + except: + # If numba fails for any reason, fallback to numpy + error = np.max(np.abs(result - v)) + else: + # Fallback for non-array inputs, dtype/shape mismatch, or empty arrays + error = np.max(np.abs(result - v)) + return error <= error_tol @@ -370,3 +392,14 @@ def _square_sum_array(a): # pragma: no cover for x in a.flat: sum_ += x**2 return sum_ + + +@njit(fastmath=True, cache=True) +def _numba_max_abs_diff(arr1: np.ndarray, arr2: np.ndarray) -> float: + # Calculate the maximum absolute difference between two arrays + max_diff = 0.0 + for i in range(arr1.size): + diff = abs(arr1[i] - arr2[i]) + if diff > max_diff: + max_diff = diff + return max_diff