diff --git a/pettingzoo/classic/tictactoe/board.py b/pettingzoo/classic/tictactoe/board.py index 35186a57a..e6fee6853 100644 --- a/pettingzoo/classic/tictactoe/board.py +++ b/pettingzoo/classic/tictactoe/board.py @@ -1,79 +1,102 @@ +TTT_PLAYER1_WIN = 0 +TTT_PLAYER2_WIN = 1 +TTT_TIE = -1 +TTT_GAME_NOT_OVER = -2 + + class Board: + """Board for a TicTacToe Game. + + This tracks the position and identity of marks on the game board + and allows checking for a winner. + + Example of usage: + + import random + board = Board() + + # random legal moves - for example purposes + def choose_move(board_obj: Board) -> int: + legal_moves = [i for i, mark in enumerate(board_obj.squares) if mark == 0] + return random.choice(legal_moves) + + player = 0 + while True: + move = choose_move(board) + board.play_turn(player, move) + status = board.game_status() + if status != TTT_GAME_NOT_OVER: + if status in [TTT_PLAYER1_WIN, TTT_PLAYER2_WIN]: + print(f"player {status} won") + else: # status == TTT_TIE + print("Tie Game") + break + player = player ^ 1 # swaps between players 0 and 1 + """ + + # indices of the winning lines: vertical(x3), horizontal(x3), diagonal(x2) + winning_combinations = [ + (0, 1, 2), + (3, 4, 5), + (6, 7, 8), + (0, 3, 6), + (1, 4, 7), + (2, 5, 8), + (0, 4, 8), + (2, 4, 6), + ] + def __init__(self): - # internally self.board.squares holds a flat representation of tic tac toe board - # where an empty board is [0, 0, 0, 0, 0, 0, 0, 0, 0] - # where indexes are column wise order + # self.squares holds a flat representation of the tic tac toe board. + # an empty board is [0, 0, 0, 0, 0, 0, 0, 0, 0]. + # player 1's squares are marked 1, while player 2's are marked 2. + # mapping of the flat indices to the 3x3 grid is as follows: # 0 3 6 # 1 4 7 # 2 5 8 - - # empty -- 0 - # player 0 -- 1 - # player 1 -- 2 self.squares = [0] * 9 - # precommute possible winning combinations - self.calculate_winners() + @property + def _n_empty_squares(self): + """The current number of empty squares on the board.""" + return self.squares.count(0) - def setup(self): - self.calculate_winners() + def reset(self): + """Remove all marks from the board.""" + self.squares = [0] * 9 def play_turn(self, agent, pos): - # if spot is empty - if self.squares[pos] != 0: - return - if agent == 0: - self.squares[pos] = 1 - elif agent == 1: - self.squares[pos] = 2 - return - - def calculate_winners(self): - winning_combinations = [] - indices = [x for x in range(0, 9)] - - # Vertical combinations - winning_combinations += [ - tuple(indices[i : (i + 3)]) for i in range(0, len(indices), 3) - ] - - # Horizontal combinations - winning_combinations += [ - tuple(indices[x] for x in range(y, len(indices), 3)) for y in range(0, 3) - ] - - # Diagonal combinations - winning_combinations.append(tuple(x for x in range(0, len(indices), 4))) - winning_combinations.append(tuple(x for x in range(2, len(indices) - 1, 2))) - - self.winning_combinations = winning_combinations - - # returns: - # -1 for no winner - # 1 -- agent 0 wins - # 2 -- agent 1 wins - def check_for_winner(self): - winner = -1 - for combination in self.winning_combinations: - states = [] - for index in combination: - states.append(self.squares[index]) - if all(x == 1 for x in states): - winner = 1 - if all(x == 2 for x in states): - winner = 2 - return winner - - def check_game_over(self): - winner = self.check_for_winner() - - if winner == -1 and all(square in [1, 2] for square in self.squares): - # tie - return True - elif winner in [1, 2]: - return True - else: - return False + """Place a mark by the agent in the spot given. + + The following are required for a move to be valid: + * The agent must be a known agent ID (either 0 or 1). + * The spot must be be empty. + * The spot must be in the board (integer: 0 <= spot <= 8) + + If any of those are not true, an assertion will fail. + """ + assert pos >= 0 and pos <= 8, "Invalid move location" + assert agent in [0, 1], "Invalid agent" + assert self.squares[pos] == 0, "Location is not empty" + + # agent is [0, 1]. board values are stored as [1, 2]. + self.squares[pos] = agent + 1 + + def game_status(self): + """Return status (winner, TTT_TIE if no winner, or TTT_GAME_NOT_OVER).""" + for indices in self.winning_combinations: + states = [self.squares[idx] for idx in indices] + if states == [1, 1, 1]: + return TTT_PLAYER1_WIN + if states == [2, 2, 2]: + return TTT_PLAYER2_WIN + if self._n_empty_squares == 0: + return TTT_TIE + return TTT_GAME_NOT_OVER def __str__(self): return str(self.squares) + + def legal_moves(self): + """Return list of legal moves (as flat indices for spaces on the board).""" + return [i for i, mark in enumerate(self.squares) if mark == 0] diff --git a/pettingzoo/classic/tictactoe/test_board.py b/pettingzoo/classic/tictactoe/test_board.py new file mode 100644 index 000000000..b8f7e9248 --- /dev/null +++ b/pettingzoo/classic/tictactoe/test_board.py @@ -0,0 +1,127 @@ +"""Test cases for TicTacToe board.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from pettingzoo.classic.tictactoe.board import ( # type: ignore + TTT_GAME_NOT_OVER, + TTT_PLAYER1_WIN, + TTT_PLAYER2_WIN, + TTT_TIE, + Board, +) + +# Note: mapping of moves to board positions are: +# 0 3 6 +# 1 4 7 +# 2 5 8 + +agent2_win = { + "moves": [ + # agent_id, position, board after move + (0, 4, [0, 0, 0, 0, 1, 0, 0, 0, 0]), + (1, 0, [2, 0, 0, 0, 1, 0, 0, 0, 0]), + (0, 2, [2, 0, 1, 0, 1, 0, 0, 0, 0]), + (1, 6, [2, 0, 1, 0, 1, 0, 2, 0, 0]), + (0, 3, [2, 0, 1, 1, 1, 0, 2, 0, 0]), + (1, 7, [2, 0, 1, 1, 1, 0, 2, 2, 0]), + (0, 1, [2, 1, 1, 1, 1, 0, 2, 2, 0]), + (1, 8, [2, 1, 1, 1, 1, 0, 2, 2, 2]), # agent 2 wins here + (0, 5, [2, 1, 1, 1, 1, 1, 2, 2, 2]), + ], + "max_step": 7, # should not get past here + "winner": TTT_PLAYER2_WIN, +} + +tie = { + "moves": [ # should be tie + (0, 0, [1, 0, 0, 0, 0, 0, 0, 0, 0]), + (1, 3, [1, 0, 0, 2, 0, 0, 0, 0, 0]), + (0, 1, [1, 1, 0, 2, 0, 0, 0, 0, 0]), + (1, 4, [1, 1, 0, 2, 2, 0, 0, 0, 0]), + (0, 5, [1, 1, 0, 2, 2, 1, 0, 0, 0]), + (1, 2, [1, 1, 2, 2, 2, 1, 0, 0, 0]), + (0, 6, [1, 1, 2, 2, 2, 1, 1, 0, 0]), + (1, 7, [1, 1, 2, 2, 2, 1, 1, 2, 0]), + (0, 8, [1, 1, 2, 2, 2, 1, 1, 2, 1]), + ], + "max_step": 8, + "winner": TTT_TIE, +} + +agent1_win = { + "moves": [ + (0, 0, [1, 0, 0, 0, 0, 0, 0, 0, 0]), + (1, 3, [1, 0, 0, 2, 0, 0, 0, 0, 0]), + (0, 1, [1, 1, 0, 2, 0, 0, 0, 0, 0]), + (1, 4, [1, 1, 0, 2, 2, 0, 0, 0, 0]), + (0, 2, [1, 1, 1, 2, 2, 0, 0, 0, 0]), # agent 1 should win here + (1, 5, [1, 1, 1, 2, 2, 2, 0, 0, 0]), + (0, 6, [1, 1, 1, 2, 2, 2, 1, 0, 0]), + (1, 7, [1, 1, 1, 2, 2, 2, 1, 2, 0]), + (0, 8, [1, 1, 1, 2, 2, 2, 1, 2, 1]), + ], + "max_step": 4, + "winner": TTT_PLAYER1_WIN, +} + + +@pytest.mark.parametrize("values", [agent1_win, agent2_win, tie]) +def test_tictactoe_board_games(values: dict[str, Any]) -> None: + """Test that TicTacToe games go as expected.""" + expected_winner = values["winner"] + max_step = values["max_step"] + + board = Board() + for i, (agent, pos, board_layout) in enumerate(values["moves"]): + assert i <= max_step, "max step exceed in tictactoe game" + board.play_turn(agent, pos) + assert board_layout == board.squares, "wrong tictactoe layout after move" + status = board.game_status() + if status != TTT_GAME_NOT_OVER: + assert i == max_step, "tictactoe game ended on wrong step" + assert status == expected_winner, "wrong winner in tictactoe board test" + break + + +def test_tictactoe_winning_boards() -> None: + """Test that winning board configurations actually win.""" + # these are the winning lines for player 1. Note that moves + # for player 2 are included to make it a legal board. + winning_lines = [ # vertical(x3), horizontal(x3), diagonal(x2) + [1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 1, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1], + [1, 0, 0, 0, 1, 0, 0, 0, 1], + [0, 0, 1, 0, 1, 0, 1, 0, 0], + ] + for line in winning_lines: + board = Board() + board.squares = line + assert board.game_status() == TTT_PLAYER1_WIN, "Bad win check in TicTacToe" + + +def test_tictactoe_bad_move() -> None: + """Test that illegal TicTacToe moves are rejected.""" + board = Board() + # 1) move out of bounds should be rejected + for outside_space in [-1, 9]: + with pytest.raises(AssertionError, match="Invalid move location"): + board.play_turn(0, outside_space) + + # 2) move by unknown agent should be rejected + for unknown_agent in [-1, 2]: + with pytest.raises(AssertionError, match="Invalid agent"): + board.play_turn(unknown_agent, 0) + + # 3) move in occupied space by either agent should be rejected + board.play_turn(0, 4) # this is fine + for agent in [0, 1]: + with pytest.raises(AssertionError, match="Location is not empty"): + board.play_turn(agent, 4) # repeating move is not valid diff --git a/pettingzoo/classic/tictactoe/tictactoe.py b/pettingzoo/classic/tictactoe/tictactoe.py index e68f900a8..e3c219c5a 100644 --- a/pettingzoo/classic/tictactoe/tictactoe.py +++ b/pettingzoo/classic/tictactoe/tictactoe.py @@ -79,11 +79,12 @@ from gymnasium.utils import EzPickle from pettingzoo import AECEnv -from pettingzoo.classic.tictactoe.board import Board +from pettingzoo.classic.tictactoe.board import TTT_GAME_NOT_OVER, TTT_TIE, Board from pettingzoo.utils import AgentSelector, wrappers def get_image(path): + """Return a pygame image loaded from the given path.""" from os import path as os_path cwd = os_path.dirname(__file__) @@ -92,6 +93,7 @@ def get_image(path): def get_font(path, size): + """Return a pygame font loaded from the given path.""" from os import path as os_path cwd = os_path.dirname(__file__) @@ -141,7 +143,7 @@ def __init__( self.rewards = {i: 0 for i in self.agents} self.terminations = {i: False for i in self.agents} self.truncations = {i: False for i in self.agents} - self.infos = {i: {"legal_moves": list(range(0, 9))} for i in self.agents} + self.infos = {i: {} for i in self.agents} self._agent_selector = AgentSelector(self.agents) self.agent_selection = self._agent_selector.reset() @@ -153,42 +155,38 @@ def __init__( if self.render_mode == "human": self.clock = pygame.time.Clock() - # Key - # ---- - # blank space = 0 - # agent 0 = 1 - # agent 1 = 2 - # An observation is list of lists, where each list represents a row - # - # [[0,0,2] - # [1,2,1] - # [2,1,0]] def observe(self, agent): board_vals = np.array(self.board.squares).reshape(3, 3) cur_player = self.possible_agents.index(agent) opp_player = (cur_player + 1) % 2 - cur_p_board = np.equal(board_vals, cur_player + 1) - opp_p_board = np.equal(board_vals, opp_player + 1) - - observation = np.stack([cur_p_board, opp_p_board], axis=2).astype(np.int8) - legal_moves = self._legal_moves() if agent == self.agent_selection else [] + observation = np.empty((3, 3, 2), dtype=np.int8) + # this will give a copy of the board that is 1 for player 1's + # marks and zero for every other square, whether empty or not. + observation[:, :, 0] = np.equal(board_vals, cur_player + 1) + observation[:, :, 1] = np.equal(board_vals, opp_player + 1) - action_mask = np.zeros(9, "int8") - for i in legal_moves: - action_mask[i] = 1 + action_mask = self._get_mask(agent) return {"observation": observation, "action_mask": action_mask} + def _get_mask(self, agent): + action_mask = np.zeros(9, dtype=np.int8) + + # Per the documentation, the mask of any agent other than the + # currently selected one is all zeros. + if agent == self.agent_selection: + for i in self.board.legal_moves(): + action_mask[i] = 1 + + return action_mask + def observation_space(self, agent): return self.observation_spaces[agent] def action_space(self, agent): return self.action_spaces[agent] - def _legal_moves(self): - return [i for i in range(len(self.board.squares)) if self.board.squares[i] == 0] - # action in this case is a value from 0 to 8 indicating position to move on tictactoe board def step(self, action): if ( @@ -196,45 +194,30 @@ def step(self, action): or self.truncations[self.agent_selection] ): return self._was_dead_step(action) - # check if input action is a valid move (0 == empty spot) - assert self.board.squares[action] == 0, "played illegal move" - # play turn - self.board.play_turn(self.agents.index(self.agent_selection), action) - - # update infos - # list of valid actions (indexes in board) - # next_agent = self.agents[(self.agents.index(self.agent_selection) + 1) % len(self.agents)] - next_agent = self._agent_selector.next() - if self.board.check_game_over(): - winner = self.board.check_for_winner() + self.board.play_turn(self.agents.index(self.agent_selection), action) - if winner == -1: - # tie + status = self.board.game_status() + if status != TTT_GAME_NOT_OVER: + if status == TTT_TIE: pass - elif winner == 1: - # agent 0 won - self.rewards[self.agents[0]] += 1 - self.rewards[self.agents[1]] -= 1 else: - # agent 1 won - self.rewards[self.agents[1]] += 1 - self.rewards[self.agents[0]] -= 1 + winner = status # either TTT_PLAYER1_WIN or TTT_PLAYER2_WIN + loser = winner ^ 1 # 0 -> 1; 1 -> 0 + self.rewards[self.agents[winner]] += 1 + self.rewards[self.agents[loser]] -= 1 # once either play wins or there is a draw, game over, both players are done self.terminations = {i: True for i in self.agents} + self._accumulate_rewards() - # Switch selection to next agents - self._cumulative_rewards[self.agent_selection] = 0 - self.agent_selection = next_agent + self.agent_selection = self._agent_selector.next() - self._accumulate_rewards() if self.render_mode == "human": self.render() def reset(self, seed=None, options=None): - # reset environment - self.board = Board() + self.board.reset() self.agents = self.possible_agents[:] self.rewards = {i: 0 for i in self.agents} @@ -244,10 +227,9 @@ def reset(self, seed=None, options=None): self.infos = {i: {} for i in self.agents} # selects the first agent self._agent_selector.reinit(self.agents) - self._agent_selector.reset() self.agent_selection = self._agent_selector.reset() - if self.screen is None: + if self.render_mode is not None and self.screen is None: pygame.init() if self.render_mode == "human": @@ -255,7 +237,7 @@ def reset(self, seed=None, options=None): (self.screen_height, self.screen_height) ) pygame.display.set_caption("Tic-Tac-Toe") - else: + elif self.render_mode == "rgb_array": self.screen = pygame.Surface((self.screen_height, self.screen_height)) def close(self): diff --git a/tutorials/SB3/test/test_sb3_action_mask.py b/tutorials/SB3/test/test_sb3_action_mask.py index de4ee3c07..2be85b1d8 100644 --- a/tutorials/SB3/test/test_sb3_action_mask.py +++ b/tutorials/SB3/test/test_sb3_action_mask.py @@ -91,7 +91,7 @@ def test_action_mask_medium(env_fn): assert ( winrate < 0.75 - ), "Policy should not perform better than 75% winrate" # 30-40% for leduc, 0% for hanabi, 0% for tic-tac-toe + ), "Policy should not perform better than 75% winrate" # 30-40% for leduc, 0% for hanabi # Watch two games (disabled by default) # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)