Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 92 additions & 7 deletions quantecon/_compute_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import warnings
import numpy as np
from numba import jit, types
from numba import njit, jit, types
from numba.extending import overload
from .game_theory.lemke_howson import _lemke_howson_tbl, _get_mixed_actions

Expand Down Expand Up @@ -46,7 +46,7 @@ def _is_approx_fp(T, v, error_tol, *args, **kwargs):

def compute_fixed_point(T, v, error_tol=1e-3, max_iter=50, verbose=2,
print_skip=5, method='iteration', *args, **kwargs):
r"""
"""
Computes and returns an approximate fixed point of the function `T`.

The default method `'iteration'` simply iterates the function given
Expand Down Expand Up @@ -126,10 +126,22 @@ def compute_fixed_point(T, v, error_tol=1e-3, max_iter=50, verbose=2,
start_time = time.time()
_print_after_skip(print_skip, it=None)


# optimization: fast path for numpy ndarray and float arrays
use_numba_diff = isinstance(v, np.ndarray) and v.dtype.kind in 'fc'

while True:
new_v = T(v, *args, **kwargs)
iterate += 1
error = np.max(np.abs(new_v - v))

# Use numba optimized difference calculation for ndarray
if use_numba_diff and isinstance(new_v, np.ndarray) and new_v.shape == v.shape:
error = _max_abs_diff(new_v, v)
elif isinstance(new_v, (float, np.floating)) and isinstance(v, (float, np.floating)):
error = _max_abs_diff_scalar(new_v, v)
else:
error = np.max(np.abs(new_v - v))


try:
v[:] = new_v
Expand Down Expand Up @@ -194,7 +206,13 @@ def _compute_fixed_point_ig(T, v, max_iter, verbose, print_skip, is_approx_fp,

if converged or iterate >= max_iter:
if verbose == 2:
error = np.max(np.abs(y_new - x_new))
# optimization: fast path for numpy
if isinstance(x_new, np.ndarray) and isinstance(y_new, np.ndarray) and x_new.shape == y_new.shape:
error = _max_abs_diff(y_new, x_new)
elif isinstance(x_new, (float, np.floating)) and isinstance(y_new, (float, np.floating)):
error = _max_abs_diff_scalar(y_new, x_new)
else:
error = np.max(np.abs(y_new - x_new))
etime = time.time() - start_time
print_skip = 1
_print_after_skip(print_skip, iterate, error, etime)
Expand All @@ -206,7 +224,12 @@ def _compute_fixed_point_ig(T, v, max_iter, verbose, print_skip, is_approx_fp,
return x_new, converged, iterate

if verbose == 2:
error = np.max(np.abs(y_new - x_new))
if isinstance(x_new, np.ndarray) and isinstance(y_new, np.ndarray) and x_new.shape == y_new.shape:
error = _max_abs_diff(y_new, x_new)
elif isinstance(x_new, (float, np.floating)) and isinstance(y_new, (float, np.floating)):
error = _max_abs_diff_scalar(y_new, x_new)
else:
error = np.max(np.abs(y_new - x_new))
etime = time.time() - start_time
_print_after_skip(print_skip, iterate, error, etime)

Expand All @@ -233,7 +256,12 @@ def _compute_fixed_point_ig(T, v, max_iter, verbose, print_skip, is_approx_fp,
break

if verbose == 2:
error = np.max(np.abs(y_new - x_new))
if isinstance(x_new, np.ndarray) and isinstance(y_new, np.ndarray) and x_new.shape == y_new.shape:
error = _max_abs_diff(y_new, x_new)
elif isinstance(x_new, (float, np.floating)) and isinstance(y_new, (float, np.floating)):
error = _max_abs_diff_scalar(y_new, x_new)
else:
error = np.max(np.abs(y_new - x_new))
etime = time.time() - start_time
_print_after_skip(print_skip, iterate, error, etime)

Expand Down Expand Up @@ -269,7 +297,12 @@ def _compute_fixed_point_ig(T, v, max_iter, verbose, print_skip, is_approx_fp,
x_new = (rho @ Y_2d[:m]).reshape(shape_Y[1:])

if verbose == 2:
error = np.max(np.abs(y_new - x_new))
if isinstance(x_new, np.ndarray) and isinstance(y_new, np.ndarray) and x_new.shape == y_new.shape:
error = _max_abs_diff(y_new, x_new)
elif isinstance(x_new, (float, np.floating)) and isinstance(y_new, (float, np.floating)):
error = _max_abs_diff_scalar(y_new, x_new)
else:
error = np.max(np.abs(y_new - x_new))
etime = time.time() - start_time
print_skip = 1
_print_after_skip(print_skip, iterate, error, etime)
Expand Down Expand Up @@ -370,3 +403,55 @@ def _square_sum_array(a): # pragma: no cover
for x in a.flat:
sum_ += x**2
return sum_



@njit(cache=True, fastmath=True)
def _max_abs_diff(new_v: np.ndarray, v: np.ndarray) -> float:
"""
Compute the maximum absolute difference between new_v and v.
Equivalent to np.max(np.abs(new_v - v)), but optimized with numba.
"""
diff = np.abs(new_v - v)
max_diff = 0.0
for i in range(diff.size):
if diff.flat[i] > max_diff:
max_diff = diff.flat[i]
return max_diff

@njit(cache=True, fastmath=True)
def _max_abs_diff_scalar(new_v: float, v: float) -> float:
"""
For scalar floats, just return the absolute difference.
"""
return abs(new_v - v)



@njit(cache=True, fastmath=True)
def _is_approx_fp_numba(T_arr, v_arr, error_tol) -> bool:
"""
numba-accelerated is_approx_fp routine for ndarrays only.
"""
diff = np.abs(T_arr - v_arr)
max_diff = 0.0
for i in range(diff.size):
if diff.flat[i] > max_diff:
max_diff = diff.flat[i]
return max_diff <= error_tol

def _is_approx_fp(T, v, error_tol, *args, **kwargs):
"""
Return True if v is an approximate fixed point for T, i.e.,
max |T(v) - v| <= error_tol.
Optimized for ndarray numeric types using numba njit.
"""
T_v = T(v, *args, **kwargs)
# fast path for numeric ndarrays
if isinstance(v, np.ndarray) and v.dtype.kind in 'fc' and isinstance(T_v, np.ndarray) and T_v.shape == v.shape:
return _is_approx_fp_numba(T_v, v, error_tol)
# handle scalar floats
if isinstance(v, (float, np.floating)) and isinstance(T_v, (float, np.floating)):
return abs(T_v - v) <= error_tol
# fallback to numpy
return np.max(np.abs(T_v - v)) <= error_tol