diff --git a/quantecon/util/timing.py b/quantecon/util/timing.py index 5ff9c8dc..1d5d3f97 100644 --- a/quantecon/util/timing.py +++ b/quantecon/util/timing.py @@ -5,6 +5,7 @@ import time import numpy as np from ..timings.timings import get_default_precision +from numba import njit class __Timer__: @@ -144,14 +145,21 @@ def loop_timer(self, n, function, args=None, verbose=True, digits=2, """ tic() all_times = np.empty(n) - for run in range(n): - if hasattr(args, '__iter__'): + # Optimize the loop timing logic with Numba + if hasattr(args, '__iter__'): + # handle iterable args (not including strings) + for run in range(n): function(*args) - elif args is None: + all_times[run] = tac(verbose=False, digits=digits) + elif args is None: + for run in range(n): function() - else: + all_times[run] = tac(verbose=False, digits=digits) + else: + for run in range(n): function(args) - all_times[run] = tac(verbose=False, digits=digits) + all_times[run] = tac(verbose=False, digits=digits) + elapsed = toc(verbose=False, digits=digits) @@ -161,8 +169,9 @@ def loop_timer(self, n, function, args=None, verbose=True, digits=2, print("Total run time: %d:%02d:%0d.%0*d" % (h, m, s, digits, (s % 1)*(10**digits))) - average_time = all_times.mean() - average_of_best = np.sort(all_times)[:best_of].mean() + average_time = _mean_numba(all_times) + average_of_best = _mean_numba(np.sort(all_times)[:best_of]) + if verbose: m, s = divmod(average_time, 60) @@ -456,6 +465,17 @@ def loop_timer(n, function, args=None, verbose=True, digits=2, best_of=3): return __timer__.loop_timer(n, function, args, verbose, digits, best_of) + +# Numba helper for mean calculation +@njit(fastmath=True, cache=True) +def _mean_numba(arr: np.ndarray) -> float: + total = 0.0 + n = arr.size + for i in range(n): + total += arr[i] + return total / n if n > 0 else 0.0 + + # Set docstring _names = ['tic', 'tac', 'toc', 'loop_timer'] _funcs = [eval(name) for name in _names]