diff --git a/quantecon/game_theory/tests/test_polymatrix_game.py b/quantecon/game_theory/tests/test_polymatrix_game.py index 56523b4d..1fb187ed 100644 --- a/quantecon/game_theory/tests/test_polymatrix_game.py +++ b/quantecon/game_theory/tests/test_polymatrix_game.py @@ -8,6 +8,7 @@ from numpy import allclose, zeros import os +from numba import njit # Mimicing quantecon.tests.util.get_data_dir data_dir_name = "game_files" @@ -22,15 +23,57 @@ def close_normal_form_games( ) -> bool: if nf1.N != nf2.N: return False + + # Handle empty games (no players) - fall back to original logic + if nf1.N == 0: + for player in range(nf1.N): + if nf1.nums_actions[player] != nf2.nums_actions[player]: + return False + for player in range(nf1.N): + if not allclose( + nf1.players[player].payoff_array, + nf2.players[player].payoff_array, + atol=atol + ): + return False + return True + + # Check action numbers using numba helper for non-empty games + if not _nums_actions_equal(nf1.nums_actions, nf2.nums_actions): + return False + + # Compare payoffs for each player for player in range(nf1.N): - if nf1.nums_actions[player] != nf2.nums_actions[player]: + # Check if shapes are different first - if so, let numpy handle it like original + arr1 = nf1.players[player].payoff_array + arr2 = nf2.players[player].payoff_array + if arr1.shape != arr2.shape: + # Fall back to original numpy allclose to maintain same error behavior + if not allclose(arr1, arr2, atol=atol): + return False + else: + # Use optimized numba version for same-shape arrays + if not _allclose_ndarray(arr1, arr2, atol): + return False + return True + + +@njit(cache=True) +def _nums_actions_equal(nums_actions1, nums_actions2): + for i in range(len(nums_actions1)): + if nums_actions1[i] != nums_actions2[i]: return False - for player in range(nf1.N): - if not allclose( - nf1.players[player].payoff_array, - nf2.players[player].payoff_array, - atol=atol - ): + return True + +@njit(cache=True) +def _allclose_ndarray(a, b, atol): + # Check shapes first + if a.shape != b.shape: + return False + arr_flat = a.ravel() + brr_flat = b.ravel() + for i in range(arr_flat.shape[0]): + if abs(arr_flat[i] - brr_flat[i]) > atol: return False return True