Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions quantecon/util/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import numpy as np
from ..timings.timings import get_default_precision
from numba import njit


class __Timer__:
Expand Down Expand Up @@ -144,14 +145,21 @@ def loop_timer(self, n, function, args=None, verbose=True, digits=2,
"""
tic()
all_times = np.empty(n)
for run in range(n):
if hasattr(args, '__iter__'):
# Optimize the loop timing logic with Numba
if hasattr(args, '__iter__'):
# handle iterable args (not including strings)
for run in range(n):
function(*args)
elif args is None:
all_times[run] = tac(verbose=False, digits=digits)
elif args is None:
for run in range(n):
function()
else:
all_times[run] = tac(verbose=False, digits=digits)
else:
for run in range(n):
function(args)
all_times[run] = tac(verbose=False, digits=digits)
all_times[run] = tac(verbose=False, digits=digits)


elapsed = toc(verbose=False, digits=digits)

Expand All @@ -161,8 +169,9 @@ def loop_timer(self, n, function, args=None, verbose=True, digits=2,
print("Total run time: %d:%02d:%0d.%0*d" %
(h, m, s, digits, (s % 1)*(10**digits)))

average_time = all_times.mean()
average_of_best = np.sort(all_times)[:best_of].mean()
average_time = _mean_numba(all_times)
average_of_best = _mean_numba(np.sort(all_times)[:best_of])


if verbose:
m, s = divmod(average_time, 60)
Expand Down Expand Up @@ -456,6 +465,17 @@ def loop_timer(n, function, args=None, verbose=True, digits=2, best_of=3):
return __timer__.loop_timer(n, function, args, verbose, digits, best_of)



# Numba helper for mean calculation
@njit(fastmath=True, cache=True)
def _mean_numba(arr: np.ndarray) -> float:
total = 0.0
n = arr.size
for i in range(n):
total += arr[i]
return total / n if n > 0 else 0.0


# Set docstring
_names = ['tic', 'tac', 'toc', 'loop_timer']
_funcs = [eval(name) for name in _names]
Expand Down