Skip to content

Commit

Permalink
add random_play
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasmarsh committed Feb 24, 2024
1 parent db65672 commit d188e9c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 50 deletions.
26 changes: 25 additions & 1 deletion demo/playground.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,35 @@ fn summarize(label_a: &str, label_b: &str, results: Vec<Option<usize>>) {
println!("{label_a} / {label_b}: {win_a} ({pct_a:.2}%) / {win_b} ({pct_b:.2}%) [{draw} draws]");
}

fn knightthrough() {
use mcts::games::knightthrough::Knightthrough;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;

type TS = TreeSearch<Knightthrough<8, 8>, util::Ucb1GraveMast>;
let ts = TS::default()
.config(
SearchConfig::default()
.max_time(Duration::from_secs(10))
.select(select::Ucb1Grave {
exploration_constant: 1.32562,
threshold: 720,
bias: 430.36,
current_ref_id: None,
})
.simulate(simulate::EpsilonGreedy::with_epsilon(0.98)),
)
.verbose(true);

self_play(ts);
}

fn breakthrough() {
use mcts::games::breakthrough::Breakthrough;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;

type TS = TreeSearch<Breakthrough<8, 8>, util::Ucb1GraveMast>;
type TS = TreeSearch<Breakthrough<6, 4>, util::Ucb1GraveMast>;
let ts = TS::default()
.config(
SearchConfig::default()
Expand Down Expand Up @@ -313,6 +336,7 @@ fn main() {
color_backtrace::install();
pretty_env_logger::init();

knightthrough();
breakthrough();
gonnect();
atarigo();
Expand Down
14 changes: 3 additions & 11 deletions src/games/atarigo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,20 +166,12 @@ impl<const N: usize> fmt::Display for State<N> {

#[cfg(test)]
mod tests {
use crate::util::random_play;

use super::*;

#[test]
fn test_atarigo() {
let mut state = State::<7>::default();
while !AtariGo::is_terminal(&state) {
println!("state:\n{state}");
let mut actions = Vec::new();
AtariGo::generate_actions(&state, &mut actions);
use rand::Rng;
let mut rng = rand::thread_rng();
assert!(!actions.is_empty());
let idx = rng.gen_range(0..actions.len());
state = AtariGo::apply(state, &actions[idx]);
}
random_play::<AtariGo<7>>();
}
}
43 changes: 17 additions & 26 deletions src/games/breakthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ impl<const N: usize, const M: usize> State<N, M> {
}

#[inline(always)]
fn player(&self, player: Player) -> (BitBoard<N, M>, BitBoard<N, M>) {
fn player(&self, player: Player) -> BitBoard<N, M> {
match player {
Player::Black => (self.black, BitBoard::wall(bitboard::Direction::South)),
Player::White => (self.white, BitBoard::wall(bitboard::Direction::North)),
Player::Black => self.black,
Player::White => self.white,
}
}

Expand All @@ -103,20 +103,19 @@ impl<const N: usize, const M: usize> State<N, M> {
return;
}

let (player, _) = self.player(self.turn);
let (opponent, _) = self.player(self.turn.next());
let (player) = self.player(self.turn);
let (opponent) = self.player(self.turn.next());
let occupied = player | opponent;

for src in player {
// TODO: use mask
let from = BitBoard::from_index(src);
let forward = match self.turn {
Player::Black => from.shift_south(),
Player::White => from.shift_north(),
};

let w = forward.shift_west();
let e = forward.shift_east();
let w = (forward & !BitBoard::wall(bitboard::Direction::West)).shift_west();
let e = (forward & !BitBoard::wall(bitboard::Direction::East)).shift_east();

let available = (w & !player) | (e & !player) | (forward & !occupied);

Expand All @@ -131,21 +130,23 @@ impl<const N: usize, const M: usize> State<N, M> {
debug_assert!(self.occupied().get(action.0 as usize));
let src = BitBoard::from_index(action.0 as usize);
let dst = BitBoard::from_index(action.1 as usize);
let (mut player, goal) = self.player(self.turn);
let mut player = self.player(self.turn);
player |= dst;
player &= !src;
let opponent = self.player(self.turn.next()).0 & !dst;
let opponent = self.player(self.turn.next()) & !dst;

match self.turn {
let goal = match self.turn {
Player::Black => {
self.black = player;
self.white = opponent;
BitBoard::wall(bitboard::Direction::South)
}
Player::White => {
self.white = player;
self.black = opponent;
BitBoard::wall(bitboard::Direction::North)
}
}
};

if player.intersects(goal) {
self.winner = true;
Expand Down Expand Up @@ -203,7 +204,7 @@ impl<const N: usize, const M: usize> Game for Breakthrough<N, M> {
impl<const N: usize, const M: usize> fmt::Display for State<N, M> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for row in (0..N).rev() {
for col in 0..N {
for col in 0..M {
if self.black.get_at(row, col) {
write!(f, " X")?;
} else if self.white.get_at(row, col) {
Expand All @@ -220,22 +221,12 @@ impl<const N: usize, const M: usize> fmt::Display for State<N, M> {

#[cfg(test)]
mod tests {
use crate::util::random_play;

use super::*;

#[test]
fn test_breakthrough() {
let mut state = State::<8, 8>::default();
println!("state:\n{state}");
while !Breakthrough::is_terminal(&state) {
let mut actions = Vec::new();
Breakthrough::generate_actions(&state, &mut actions);
use rand::Rng;
let mut rng = rand::thread_rng();
assert!(!actions.is_empty());
let idx = rng.gen_range(0..actions.len());
state = Breakthrough::apply(state, &actions[idx]);
println!("state:\n{state}");
}
println!("winner: {:?}", Breakthrough::winner(&state));
random_play::<Breakthrough<8, 8>>();
}
}
15 changes: 3 additions & 12 deletions src/games/gonnect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,21 +235,12 @@ impl<const N: usize> fmt::Display for MovesDisplay<N> {

#[cfg(test)]
mod tests {
use crate::util::random_play;

use super::*;

#[test]
fn test_gonnect() {
let mut state = State::<3>::default();
while !Gonnect::is_terminal(&state) {
println!("state: ({:?} to play)\n{state}", state.turn);
println!("moves:\n{}", MovesDisplay(state));
let mut actions = Vec::new();
Gonnect::generate_actions(&state, &mut actions);
use rand::Rng;
let mut rng = rand::thread_rng();
assert!(!actions.is_empty());
let idx = rng.gen_range(0..actions.len());
state = Gonnect::apply(state, &actions[idx]);
}
random_play::<Gonnect<3>>();
}
}

0 comments on commit d188e9c

Please sign in to comment.