Skip to content

Commit

Permalink
Add AtariGo
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasmarsh committed Feb 23, 2024
1 parent 60bfb11 commit 27efae0
Show file tree
Hide file tree
Showing 6 changed files with 581 additions and 149 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "mcts"
version = "0.1.0"
edition = "2021"
default-run = "mcts"
default-run = "druid"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -15,8 +15,8 @@ name = "mcts"
path = "src/lib.rs"

[[bin]]
name = "mcts"
path = "demo/main.rs"
name = "playground"
path = "demo/playground.rs"

[[bin]]
name = "druid"
Expand Down
27 changes: 27 additions & 0 deletions demo/playground.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,32 @@ fn summarize(label_a: &str, label_b: &str, results: Vec<Option<usize>>) {
println!("{label_a} / {label_b}: {win_a} ({pct_a:.2}%) / {win_b} ({pct_b:.2}%) [{draw} draws]");
}

/// Plays one self-play game of Atari Go on a 4x4 board (`AtariGo<4>`),
/// printing the position after every move and the winning player at the end.
fn atarigo() {
    use mcts::games::atarigo;
    use mcts::games::atarigo::AtariGo;
    use mcts::strategies::mcts::select;

    type TS = TreeSearch<AtariGo<4>, util::Ucb1Tuned>;
    let mut search = TS::default()
        .config(
            SearchConfig::default()
                .select(select::Ucb1Tuned {
                    exploration_constant: 1.625,
                })
                .max_time(Duration::from_secs(5))
                .expand_threshold(1),
        )
        .verbose(true);

    let mut state = atarigo::State::default();
    println!("state:\n{state}");
    while !AtariGo::is_terminal(&state) {
        let action = search.choose_action(&state);
        state = AtariGo::apply(state, &action);
        println!("state:\n{state}");
    }

    // `state.winner` is only a "game over" flag and is always `true` once the
    // loop above exits, so report the winning player instead of the flag.
    println!("winner: {:?}", AtariGo::winner(&state));
}

fn expansion_test() {
use mcts::games::bid_ttt as ttt;
type TS = TreeSearch<ttt::BiddingTicTacToe, util::Ucb1>;
Expand Down Expand Up @@ -243,6 +269,7 @@ fn main() {
color_backtrace::install();
pretty_env_logger::init();

atarigo();
expansion_test();
ucb_test();

Expand Down
168 changes: 154 additions & 14 deletions src/games/atarigo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,67 +6,207 @@ use crate::game::Game;
use crate::game::PlayerIndex;

use serde::Serialize;
use std::fmt;

/// One of the two sides in the game.
///
/// `Black` is the `Default` variant, so a default-constructed `State` has
/// Black to move.
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Serialize, Debug, Default)]
pub enum Player {
#[default]
Black,
White,
}

impl Player {
    /// Returns the opponent of `self`.
    fn next(self) -> Player {
        if matches!(self, Player::Black) {
            Player::White
        } else {
            Player::Black
        }
    }
}

impl PlayerIndex for Player {
fn to_index(&self) -> usize {
*self as usize
}
}

/// A single stone placement.
///
/// Field 0 is the board index the stone is placed on. Field 1 holds the raw
/// bits of the opposing stones that the placement captures; it is 0 when the
/// move captures nothing.
#[derive(Clone, Copy, Serialize, Debug, Hash, PartialEq, Eq)]
pub struct Move(u8);
pub struct Move(u8, u64);

/// Position of an Atari Go game on an `N` x `N` board.
#[derive(Clone, Copy, Serialize, Debug)]
#[derive(Clone, Copy, Serialize, Debug, Default)]
pub struct State<const N: usize> {
// Stones of each colour, one bitboard per side.
black: BitBoard<N, N>,
white: BitBoard<N, N>,
// Player to move; once `winner` is set this is the player who made the capture.
turn: Player,
// Set to `true` by the first capturing move, which ends the game.
pub winner: bool,
}

impl<const N: usize> State<N> {
#[inline(always)]
fn occupied(&self) -> BitBoard<N, N> {
self.black | self.white
}

#[inline(always)]
fn player(&self, player: Player) -> BitBoard<N, N> {
match player {
Player::Black => self.black,
Player::White => self.white,
}
}

#[inline(always)]
fn color(&self, index: usize) -> Player {
debug_assert!(self.occupied().get(index));
if self.black.get(index) {
Player::Black
} else {
debug_assert!(self.white.get(index));

Player::White
}
}

fn valid(&self, index: usize) -> (bool, BitBoard<N, N>) {
assert!(!self.occupied().get(index));
let player = self.player(self.turn) | BitBoard::from_index(index);
let opponent = self.player(self.turn.next());
let occupied = player | opponent;
let group = player.flood4(index);
let adjacent = group.adjacency_mask();
let occupied_adjacent = (occupied & adjacent);
let empty_adjacent = !occupied_adjacent;

// If we have adjacent empty positions we still have liberties.
let safe = !(empty_adjacent.is_empty());

let mut seen = BitBoard::empty();
let mut will_capture = BitBoard::empty();
for point in occupied_adjacent {
// By definition, adjacent non-empty points must be the opponent
assert!(occupied.get(point));
assert!(opponent.get(point));
if !seen.get(point) {
let group = opponent.flood4(point);
let adjacent = group.adjacency_mask();
let empty_adjacent = !occupied & adjacent;
if empty_adjacent.is_empty() {
will_capture |= group;
}
seen |= group;
}
}

(safe || !(will_capture.is_empty()), will_capture)
}

#[inline]
fn apply(&mut self, action: &Move) -> Self {
debug_assert!(!self.occupied().get(action.0 as usize));
let player = self.player(self.turn) | BitBoard::from_index(action.0 as usize);
let opponent = self.player(self.turn.next());
match self.turn {
Player::Black => {
self.black = player;
self.white = opponent & (!BitBoard::new(action.1));
}
Player::White => {
self.white = player;
self.black = opponent & (!BitBoard::new(action.1));
}
}
if action.1 > 0 {
self.winner = true;
} else {
self.turn = self.turn.next();
}

*self
}
}

/// Marker type for Atari Go on an `N` x `N` board; it carries no data and
/// exists to host the `Game` implementation.
#[derive(Clone)]
struct AtariGo<const N: usize>;
pub struct AtariGo<const N: usize>;

// `Game` implementation for Atari Go: states are `State<N>`, actions are
// `Move`, and players are `Player`.
impl<const N: usize> Game for AtariGo<N> {
type S = State<N>;
type A = Move;
type P = Player;

fn apply(state: State<N>, action: &Move) -> State<N> {
// 1. Place piece
// 2. Scan for groups with no liberties and remove
todo!();
// Delegates to `State::apply`, which places the stone and removes the
// captured stones recorded in the move.
fn apply(mut state: State<N>, action: &Move) -> State<N> {
state.apply(action)
}

fn generate_actions(state: &State<N>, actions: &mut Vec<Move>) {
// 1. Most open points are playable
// 2. ...unless they would result in self capture
todo!();
// Visit every unoccupied point and keep the ones `State::valid` accepts.
// The stones the move would capture are stored in the move itself so that
// `apply` does not have to recompute them.
for index in !state.occupied() {
let (valid, will_capture) = state.valid(index);
if valid {
actions.push(Move(index as u8, will_capture.get_raw()))
}
}
}

fn is_terminal(state: &State<N>) -> bool {
todo!();
// The game is over as soon as a capture has set the `winner` flag.
state.winner
}

fn player_to_move(state: &State<N>) -> Player {
todo!();
state.turn
}

fn winner(state: &State<N>) -> Option<Player> {
todo!();
// `State::apply` does not advance `turn` on a capturing move, so `turn`
// still names the player who made the capture.
if state.winner {
Some(state.turn)
} else {
None
}
}

// Formats a move as a column letter followed by the 1-based row, e.g. "A1".
// `state` is not used.
// NOTE(review): only eight column letters are defined, so a column index
// of 8 or more would panic on the `COL_NAMES[col]` lookup.
fn notation(state: &Self::S, action: &Self::A) -> String {
const COL_NAMES: &[u8] = b"ABCDEFGH";
let (col, row) = BitBoard::<N, N>::to_coord(action.0 as usize);
format!("{}{}", COL_NAMES[col] as char, row + 1)
}

fn num_players() -> usize {
2
}
}

impl<const N: usize> fmt::Display for State<N> {
    /// Renders the board as `N` newline-terminated lines of `N` characters,
    /// highest row index first: `X` for black, `O` for white, `.` for empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for board_row in (0..N).rev() {
            for col in 0..N {
                let glyph = if self.black.get_at(board_row, col) {
                    "X"
                } else if self.white.get_at(board_row, col) {
                    "O"
                } else {
                    "."
                };
                f.write_str(glyph)?;
            }
            f.write_str("\n")?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Plays uniformly random generated moves on a 7x7 board until the game
    /// is terminal, asserting that a move is available at every step.
    #[test]
    fn test_atarigo() {
        use rand::Rng;

        let mut state = State::<7>::default();
        loop {
            if AtariGo::is_terminal(&state) {
                break;
            }
            println!("state:\n{state}");

            let mut candidates = Vec::new();
            AtariGo::generate_actions(&state, &mut candidates);

            let mut rng = rand::thread_rng();
            assert!(!candidates.is_empty());
            let pick = rng.gen_range(0..candidates.len());
            state = AtariGo::apply(state, &candidates[pick]);
        }
    }
}
Loading

0 comments on commit 27efae0

Please sign in to comment.