Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions quantecon/game_theory/tests/test_polymatrix_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from numpy import allclose, zeros

import os
from numba import njit

# Mimicing quantecon.tests.util.get_data_dir
data_dir_name = "game_files"
Expand All @@ -22,15 +23,57 @@ def close_normal_form_games(
) -> bool:
if nf1.N != nf2.N:
return False

# Handle empty games (no players) - fall back to original logic
if nf1.N == 0:
for player in range(nf1.N):
if nf1.nums_actions[player] != nf2.nums_actions[player]:
return False
for player in range(nf1.N):
if not allclose(
nf1.players[player].payoff_array,
nf2.players[player].payoff_array,
atol=atol
):
return False
return True

# Check action numbers using numba helper for non-empty games
if not _nums_actions_equal(nf1.nums_actions, nf2.nums_actions):
return False

# Compare payoffs for each player
for player in range(nf1.N):
if nf1.nums_actions[player] != nf2.nums_actions[player]:
# Check if shapes are different first - if so, let numpy handle it like original
arr1 = nf1.players[player].payoff_array
arr2 = nf2.players[player].payoff_array
if arr1.shape != arr2.shape:
# Fall back to original numpy allclose to maintain same error behavior
if not allclose(arr1, arr2, atol=atol):
return False
else:
# Use optimized numba version for same-shape arrays
if not _allclose_ndarray(arr1, arr2, atol):
return False
return True


@njit(cache=True)
def _nums_actions_equal(nums_actions1, nums_actions2):
for i in range(len(nums_actions1)):
if nums_actions1[i] != nums_actions2[i]:
return False
for player in range(nf1.N):
if not allclose(
nf1.players[player].payoff_array,
nf2.players[player].payoff_array,
atol=atol
):
return True

@njit(cache=True)
def _allclose_ndarray(a, b, atol):
# Check shapes first
if a.shape != b.shape:
return False
arr_flat = a.ravel()
brr_flat = b.ravel()
for i in range(arr_flat.shape[0]):
if abs(arr_flat[i] - brr_flat[i]) > atol:
return False
return True

Expand Down