Skip to content

Commit d6cef8b

Browse files
authored
Reintroduce Open Spiel proxy games and custom UIs (#346)
* Reapply "Add support for game-specific Open Spiel UIs" (#344) This reverts commit b3d5520. * Use relative import in test_open_spiel.py. * Fix html_renderer_callable to avoid late-binding overwriting.
1 parent 51d1919 commit d6cef8b

File tree

11 files changed

+859
-22
lines changed

11 files changed

+859
-22
lines changed

kaggle_environments/envs/open_spiel/__init__.py

Whitespace-only changes.

kaggle_environments/envs/open_spiel/games/__init__.py

Whitespace-only changes.

kaggle_environments/envs/open_spiel/games/connect_four/__init__.py

Whitespace-only changes.
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
function renderer(options) {
2+
const { environment, step, parent, interactive, isInteractive } = options;
3+
4+
const DEFAULT_NUM_ROWS = 6;
5+
const DEFAULT_NUM_COLS = 7;
6+
const PLAYER_SYMBOLS = ['O', 'X']; // O: Player 0 (Yellow), X: Player 1 (Red)
7+
const PLAYER_COLORS = ['#facc15', '#ef4444']; // Yellow for 'O', Red for 'X'
8+
const EMPTY_CELL_COLOR = '#e5e7eb';
9+
const BOARD_COLOR = '#3b82f6';
10+
11+
const SVG_NS = "http://www.w3.org/2000/svg";
12+
const CELL_UNIT_SIZE = 100;
13+
const CIRCLE_RADIUS = CELL_UNIT_SIZE * 0.42;
14+
const SVG_VIEWBOX_WIDTH = DEFAULT_NUM_COLS * CELL_UNIT_SIZE;
15+
const SVG_VIEWBOX_HEIGHT = DEFAULT_NUM_ROWS * CELL_UNIT_SIZE;
16+
17+
let currentBoardSvgElement = null;
18+
let currentStatusTextElement = null;
19+
let currentWinnerTextElement = null;
20+
let currentMessageBoxElement = typeof document !== 'undefined' ? document.getElementById('messageBox') : null;
21+
let currentRendererContainer = null;
22+
let currentTitleElement = null;
23+
24+
function _showMessage(message, type = 'info', duration = 3000) {
25+
if (typeof document === 'undefined' || !document.body) return;
26+
if (!currentMessageBoxElement) {
27+
currentMessageBoxElement = document.createElement('div');
28+
currentMessageBoxElement.id = 'messageBox';
29+
currentMessageBoxElement.style.position = 'fixed';
30+
currentMessageBoxElement.style.top = '10px';
31+
currentMessageBoxElement.style.left = '50%';
32+
currentMessageBoxElement.style.transform = 'translateX(-50%)';
33+
currentMessageBoxElement.style.padding = '0.75rem 1rem';
34+
currentMessageBoxElement.style.borderRadius = '0.375rem';
35+
currentMessageBoxElement.style.boxShadow = '0 2px 4px rgba(0,0,0,0.1)';
36+
currentMessageBoxElement.style.zIndex = '1000';
37+
currentMessageBoxElement.style.opacity = '0';
38+
currentMessageBoxElement.style.transition = 'opacity 0.3s ease-in-out, background-color 0.3s';
39+
currentMessageBoxElement.style.fontSize = '0.875rem';
40+
currentMessageBoxElement.style.fontFamily = "'Inter', sans-serif";
41+
document.body.appendChild(currentMessageBoxElement);
42+
}
43+
currentMessageBoxElement.textContent = message;
44+
currentMessageBoxElement.style.backgroundColor = type === 'error' ? '#ef4444' : '#10b981';
45+
currentMessageBoxElement.style.color = 'white';
46+
currentMessageBoxElement.style.opacity = '1';
47+
setTimeout(() => { if (currentMessageBoxElement) currentMessageBoxElement.style.opacity = '0'; }, duration);
48+
}
49+
50+
function _ensureRendererElements(parentElementToClear, rows, cols) {
51+
if (!parentElementToClear) return false;
52+
parentElementToClear.innerHTML = '';
53+
54+
currentRendererContainer = document.createElement('div');
55+
currentRendererContainer.style.display = 'flex';
56+
currentRendererContainer.style.flexDirection = 'column';
57+
currentRendererContainer.style.alignItems = 'center';
58+
currentRendererContainer.style.padding = '20px';
59+
currentRendererContainer.style.boxSizing = 'border-box';
60+
currentRendererContainer.style.width = '100%';
61+
currentRendererContainer.style.height = '100%';
62+
currentRendererContainer.style.fontFamily = "'Inter', sans-serif";
63+
64+
currentTitleElement = document.createElement('h1');
65+
currentTitleElement.textContent = 'Connect Four';
66+
currentTitleElement.style.fontSize = '1.875rem';
67+
currentTitleElement.style.fontWeight = 'bold';
68+
currentTitleElement.style.marginBottom = '1rem';
69+
currentTitleElement.style.textAlign = 'center';
70+
currentTitleElement.style.color = '#2563eb';
71+
currentRendererContainer.appendChild(currentTitleElement);
72+
73+
currentBoardSvgElement = document.createElementNS(SVG_NS, "svg");
74+
currentBoardSvgElement.setAttribute("viewBox", `0 0 ${SVG_VIEWBOX_WIDTH} ${SVG_VIEWBOX_HEIGHT}`);
75+
currentBoardSvgElement.setAttribute("preserveAspectRatio", "xMidYMid meet");
76+
currentBoardSvgElement.style.width = "auto";
77+
currentBoardSvgElement.style.maxWidth = "500px";
78+
currentBoardSvgElement.style.maxHeight = `calc(100vh - 200px)`;
79+
currentBoardSvgElement.style.aspectRatio = `${cols} / ${rows}`;
80+
currentBoardSvgElement.style.display = "block";
81+
currentBoardSvgElement.style.margin = "0 auto 20px auto";
82+
83+
const boardBgRect = document.createElementNS(SVG_NS, "rect");
84+
boardBgRect.setAttribute("x", "0");
85+
boardBgRect.setAttribute("y", "0");
86+
boardBgRect.setAttribute("width", SVG_VIEWBOX_WIDTH.toString());
87+
boardBgRect.setAttribute("height", SVG_VIEWBOX_HEIGHT.toString());
88+
boardBgRect.setAttribute("fill", BOARD_COLOR);
89+
boardBgRect.setAttribute("rx", (CELL_UNIT_SIZE * 0.1).toString());
90+
currentBoardSvgElement.appendChild(boardBgRect);
91+
92+
// SVG Circles are created with (0,0) being top-left visual circle
93+
for (let r_visual = 0; r_visual < rows; r_visual++) {
94+
for (let c_visual = 0; c_visual < cols; c_visual++) {
95+
const circle = document.createElementNS(SVG_NS, "circle");
96+
const cx = c_visual * CELL_UNIT_SIZE + CELL_UNIT_SIZE / 2;
97+
const cy = r_visual * CELL_UNIT_SIZE + CELL_UNIT_SIZE / 2;
98+
circle.setAttribute("id", `cell-${r_visual}-${c_visual}`);
99+
circle.setAttribute("cx", cx.toString());
100+
circle.setAttribute("cy", cy.toString());
101+
circle.setAttribute("r", CIRCLE_RADIUS.toString());
102+
circle.setAttribute("fill", EMPTY_CELL_COLOR);
103+
currentBoardSvgElement.appendChild(circle);
104+
}
105+
}
106+
currentRendererContainer.appendChild(currentBoardSvgElement);
107+
108+
const statusContainer = document.createElement('div');
109+
statusContainer.style.padding = '10px 15px';
110+
statusContainer.style.backgroundColor = 'white';
111+
statusContainer.style.borderRadius = '8px';
112+
statusContainer.style.boxShadow = '0 4px 6px -1px rgba(0,0,0,0.1), 0 2px 4px -1px rgba(0,0,0,0.06)';
113+
statusContainer.style.textAlign = 'center';
114+
statusContainer.style.width = 'auto';
115+
statusContainer.style.minWidth = '200px';
116+
statusContainer.style.maxWidth = '90vw';
117+
currentRendererContainer.appendChild(statusContainer);
118+
119+
currentStatusTextElement = document.createElement('p');
120+
currentStatusTextElement.style.fontSize = '1.1rem';
121+
currentStatusTextElement.style.fontWeight = '600';
122+
currentStatusTextElement.style.margin = '0 0 5px 0';
123+
statusContainer.appendChild(currentStatusTextElement);
124+
125+
currentWinnerTextElement = document.createElement('p');
126+
currentWinnerTextElement.style.fontSize = '1.25rem';
127+
currentWinnerTextElement.style.fontWeight = '700';
128+
currentWinnerTextElement.style.margin = '5px 0 0 0';
129+
statusContainer.appendChild(currentWinnerTextElement);
130+
131+
parentElementToClear.appendChild(currentRendererContainer);
132+
133+
if (typeof document !== 'undefined' && !document.body.hasAttribute('data-renderer-initialized')) {
134+
document.body.setAttribute('data-renderer-initialized', 'true');
135+
}
136+
return true;
137+
}
138+
139+
function _renderBoardDisplay_svg(gameStateToDisplay, displayRows, displayCols) {
140+
if (!currentBoardSvgElement || !currentStatusTextElement || !currentWinnerTextElement) return;
141+
142+
if (!gameStateToDisplay || typeof gameStateToDisplay.board !== 'object' || !Array.isArray(gameStateToDisplay.board) || gameStateToDisplay.board.length === 0) {
143+
currentStatusTextElement.textContent = "Waiting for game data...";
144+
currentWinnerTextElement.textContent = "";
145+
for (let r_visual = 0; r_visual < displayRows; r_visual++) {
146+
for (let c_visual = 0; c_visual < displayCols; c_visual++) {
147+
const circleElement = currentBoardSvgElement.querySelector(`#cell-${r_visual}-${c_visual}`);
148+
if (circleElement) {
149+
circleElement.setAttribute("fill", EMPTY_CELL_COLOR);
150+
}
151+
}
152+
}
153+
return;
154+
}
155+
156+
const { board, current_player, is_terminal, winner } = gameStateToDisplay;
157+
158+
for (let r_data = 0; r_data < displayRows; r_data++) {
159+
const dataRow = board[r_data];
160+
if (!dataRow || !Array.isArray(dataRow) || dataRow.length !== displayCols) {
161+
// Error handling for malformed row
162+
for (let c_fill = 0; c_fill < displayCols; c_fill++) {
163+
// Determine visual row for error display. If r_data=0 is top data,
164+
// and we want to flip, then this error is for visual row (displayRows-1)-0.
165+
const visual_row_for_error = (displayRows - 1) - r_data;
166+
const circleElement = currentBoardSvgElement.querySelector(`#cell-${visual_row_for_error}-${c_fill}`);
167+
if (circleElement) circleElement.setAttribute("fill", '#FF00FF'); // Magenta for error
168+
}
169+
continue;
170+
}
171+
172+
const visual_svg_row_index = (displayRows - 1) - r_data;
173+
174+
for (let c_data = 0; c_data < displayCols; c_data++) { // c_data iterates through columns of `board[r_data]`
175+
const originalCellValue = dataRow[c_data];
176+
const cellValueForComparison = String(originalCellValue).trim().toLowerCase();
177+
178+
// The column index for SVG is the same as c_data
179+
const visual_svg_col_index = c_data;
180+
const circleElement = currentBoardSvgElement.querySelector(`#cell-${visual_svg_row_index}-${visual_svg_col_index}`);
181+
182+
if (!circleElement) continue;
183+
184+
let fillColor = EMPTY_CELL_COLOR;
185+
if (cellValueForComparison === "o") {
186+
fillColor = PLAYER_COLORS[0]; // Yellow
187+
} else if (cellValueForComparison === "x") {
188+
fillColor = PLAYER_COLORS[1]; // Red
189+
}
190+
circleElement.setAttribute("fill", fillColor);
191+
}
192+
}
193+
194+
currentStatusTextElement.innerHTML = '';
195+
currentWinnerTextElement.innerHTML = '';
196+
if (is_terminal) {
197+
currentStatusTextElement.textContent = "Game Over!";
198+
if (winner !== null && winner !== undefined) {
199+
if (String(winner).toLowerCase() === 'draw') {
200+
currentWinnerTextElement.textContent = "It's a Draw!";
201+
} else {
202+
let winnerSymbolDisplay, winnerColorDisplay;
203+
if (String(winner).toLowerCase() === "o") {
204+
winnerSymbolDisplay = PLAYER_SYMBOLS[0];
205+
winnerColorDisplay = PLAYER_COLORS[0];
206+
} else if (String(winner).toLowerCase() === "x") {
207+
winnerSymbolDisplay = PLAYER_SYMBOLS[1];
208+
winnerColorDisplay = PLAYER_COLORS[1];
209+
}
210+
if (winnerSymbolDisplay) {
211+
currentWinnerTextElement.innerHTML = `Player <span style="color: ${winnerColorDisplay}; font-weight: bold;">${winnerSymbolDisplay}</span> Wins!`;
212+
} else {
213+
currentWinnerTextElement.textContent = `Winner: ${String(winner).toUpperCase()}`;
214+
}
215+
}
216+
} else { currentWinnerTextElement.textContent = "Game ended."; }
217+
} else {
218+
let playerSymbolToDisplay, playerColorToDisplay;
219+
if (String(current_player).toLowerCase() === "o") {
220+
playerSymbolToDisplay = PLAYER_SYMBOLS[0];
221+
playerColorToDisplay = PLAYER_COLORS[0];
222+
} else if (String(current_player).toLowerCase() === "x") {
223+
playerSymbolToDisplay = PLAYER_SYMBOLS[1];
224+
playerColorToDisplay = PLAYER_COLORS[1];
225+
}
226+
if (playerSymbolToDisplay) {
227+
currentStatusTextElement.innerHTML = `Current Player: <span style="color: ${playerColorToDisplay}; font-weight: bold;">${playerSymbolToDisplay}</span>`;
228+
} else {
229+
currentStatusTextElement.textContent = "Waiting for player...";
230+
}
231+
}
232+
}
233+
234+
// --- Main execution logic ---
235+
if (!_ensureRendererElements(parent, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS)) {
236+
if (parent && typeof parent.innerHTML !== 'undefined') {
237+
parent.innerHTML = "<p style='color:red; font-family: sans-serif;'>Critical Error: Renderer element setup failed.</p>";
238+
}
239+
return;
240+
}
241+
242+
if (!environment || !environment.steps || !environment.steps[step]) {
243+
_renderBoardDisplay_svg(null, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
244+
if(currentStatusTextElement) currentStatusTextElement.textContent = "Initializing environment...";
245+
return;
246+
}
247+
248+
const currentStepAgents = environment.steps[step];
249+
if (!currentStepAgents || !Array.isArray(currentStepAgents) || currentStepAgents.length === 0) {
250+
_renderBoardDisplay_svg(null, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
251+
if(currentStatusTextElement) currentStatusTextElement.textContent = "Waiting for agent data...";
252+
return;
253+
}
254+
255+
const gameMasterAgentIndex = currentStepAgents.length - 1;
256+
const gameMasterAgent = currentStepAgents[gameMasterAgentIndex];
257+
258+
if (!gameMasterAgent || typeof gameMasterAgent.observation === 'undefined') {
259+
_renderBoardDisplay_svg(null, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
260+
if(currentStatusTextElement) currentStatusTextElement.textContent = "Waiting for observation data...";
261+
return;
262+
}
263+
const observationForRenderer = gameMasterAgent.observation;
264+
265+
let gameSpecificState = null;
266+
267+
if (observationForRenderer && typeof observationForRenderer.observation_string === 'string' && observationForRenderer.observation_string.trim() !== '') {
268+
try {
269+
gameSpecificState = JSON.parse(observationForRenderer.observation_string);
270+
} catch (e) {
271+
_showMessage("Error: Corrupted game state (obs_string).", 'error');
272+
}
273+
}
274+
275+
if (!gameSpecificState && observationForRenderer && typeof observationForRenderer.json === 'string' && observationForRenderer.json.trim() !== '') {
276+
try {
277+
gameSpecificState = JSON.parse(observationForRenderer.json);
278+
} catch (e) {
279+
_showMessage("Error: Corrupted game state (json).", 'error');
280+
}
281+
}
282+
283+
if (!gameSpecificState && observationForRenderer &&
284+
Array.isArray(observationForRenderer.board) &&
285+
typeof observationForRenderer.current_player !== 'undefined'
286+
) {
287+
if( (observationForRenderer.board.length === DEFAULT_NUM_ROWS &&
288+
(observationForRenderer.board.length === 0 ||
289+
(Array.isArray(observationForRenderer.board[0]) && observationForRenderer.board[0].length === DEFAULT_NUM_COLS)))
290+
){
291+
gameSpecificState = observationForRenderer;
292+
}
293+
}
294+
295+
_renderBoardDisplay_svg(gameSpecificState, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
296+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Change Connect Four state and action string representations."""
2+
3+
import json
4+
from typing import Any
5+
6+
from ... import proxy
7+
import pyspiel
8+
9+
10+
class ConnectFourState(proxy.State):
11+
"""Connect Four state proxy."""
12+
13+
def _player_string(self, player: int) -> str:
14+
if player < 0:
15+
return pyspiel.PlayerId(player).name.lower()
16+
elif player == 0:
17+
return 'x'
18+
elif player == 1:
19+
return 'o'
20+
else:
21+
raise ValueError(f'Invalid player: {player}')
22+
23+
def state_dict(self) -> dict[str, Any]:
24+
# row 0 is now bottom row
25+
rows = reversed(self.to_string().strip().split('\n'))
26+
board = [list(row) for row in rows]
27+
winner = None
28+
if self.is_terminal():
29+
if self.returns()[0] > self.returns()[1]:
30+
winner = 'x'
31+
elif self.returns()[1] > self.returns()[0]:
32+
winner = 'o'
33+
else:
34+
winner = 'draw'
35+
return {
36+
'board': board,
37+
'current_player': self._player_string(self.current_player()),
38+
'is_terminal': self.is_terminal(),
39+
'winner': winner,
40+
}
41+
42+
def to_json(self) -> str:
43+
return json.dumps(self.state_dict())
44+
45+
def action_to_dict(self, action: int) -> dict[str, Any]:
46+
return {'col': action}
47+
48+
def action_to_json(self, action: int) -> str:
49+
return json.dumps(self.action_to_dict(action))
50+
51+
def dict_to_action(self, action_dict: dict[str, Any]) -> int:
52+
return int(action_dict['col'])
53+
54+
def json_to_action(self, action_json: str) -> int:
55+
action_dict = json.loads(action_json)
56+
return self.dict_to_action(action_dict)
57+
58+
def observation_string(self, player: int) -> str:
59+
return self.observation_json(player)
60+
61+
def observation_json(self, player: int) -> str:
62+
del player
63+
return self.to_json()
64+
65+
def __str__(self):
66+
return self.to_json()
67+
68+
69+
class ConnectFourGame(proxy.Game):
70+
"""Connect Four game proxy."""
71+
72+
def __init__(self, params: Any | None = None):
73+
params = params or {}
74+
wrapped = pyspiel.load_game('connect_four', params)
75+
super().__init__(
76+
wrapped,
77+
short_name='connect_four_proxy',
78+
long_name='Connect Four (proxy)',
79+
)
80+
81+
def new_initial_state(self, *args) -> ConnectFourState:
82+
return ConnectFourState(self.__wrapped__.new_initial_state(*args),
83+
game=self)
84+
85+
86+
pyspiel.register_game(ConnectFourGame().get_type(), ConnectFourGame)

0 commit comments

Comments
 (0)