Skip to content

Commit

Permalink
Add human agent
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasmarsh committed Feb 26, 2024
1 parent 27a3460 commit e071034
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 2 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ path = "demo/playground.rs"
name = "druid"
path = "demo/druid.rs"

[[bin]]
name = "human"
path = "demo/human.rs"

[[bin]]
name = "hyper"
path = "demo/hyper.rs"
Expand Down
30 changes: 30 additions & 0 deletions demo/human.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use mcts::strategies::human::HumanAgent;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::util;
use mcts::strategies::mcts::SearchConfig;
use mcts::strategies::mcts::TreeSearch;
use mcts::util::battle_royale;

fn main() {
use mcts::games::gonnect::Gonnect;

type TS = TreeSearch<Gonnect<7>, util::Ucb1Grave>;
let mut ts = TS::default()
.config(
SearchConfig::default()
.select(select::Ucb1Grave {
exploration_constant: 1.32,
threshold: 700,
bias: 430.,
current_ref_id: None,
})
.max_iterations(300000)
// .max_time(Duration::from_secs(10))
.expand_threshold(1),
)
.verbose(true);

let mut human = HumanAgent::new();

_ = battle_royale(&mut human, &mut ts);
}
5 changes: 5 additions & 0 deletions src/game.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ pub trait Game: Sized + Clone + Sync + Send {
Self::compute_utilities(term)[Self::player_to_move(init).to_index()]
}

#[allow(unused_variables)]
fn parse_action(state: &Self::S, input: &str) -> Option<Self::A> {
unimplemented!();
}

// #[inline]
// fn rank_to_util(rank: f64, num_players: usize) -> f64 {
// let n = num_players as f64;
Expand Down
49 changes: 47 additions & 2 deletions src/games/gonnect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,38 @@ impl<const N: usize> Game for Gonnect<N> {
}
}

fn parse_action(state: &State<N>, input: &str) -> Option<Self::A> {
let mut chars = input.chars();

if let Some(file) = chars.next() {
let col = file.to_ascii_uppercase() as usize - 'A' as usize;
if col >= 0 && col < N {
if let Ok(row) = chars
.collect::<String>()
.trim()
.parse::<usize>()
.map(|x| x - 1)
{
if row >= 0 && row < N {
let index = BitBoard::<N, N>::to_index(row, col);
let (valid, will_capture) = state.valid(index);
let is_ko = state.is_ko(index, will_capture);
if valid && !is_ko {
return Some(Move(index as u8, will_capture.get_raw()));
} else {
eprintln!("invalid placement: (valid={valid}, is_ko={is_ko})");
}
} else {
eprintln!("row out of range: {row} must be >= 1 and <= {N}");
}
}
} else {
eprintln!("col out of range: {col} must be >= 1 and <= {N}");
}
}
None
}

fn notation(state: &Self::S, action: &Self::A) -> String {
const COL_NAMES: &[u8] = b"ABCDEFGH";
let (row, col) = BitBoard::<N, N>::to_coord(action.0 as usize);
Expand All @@ -185,7 +217,14 @@ impl<const N: usize> Game for Gonnect<N> {

impl<const N: usize> fmt::Display for State<N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
const FILES: &[u8] = b"ABCDEFGH";
write!(f, " ")?;
for c in FILES.iter().take(N) {
write!(f, " {}", *c as char)?;
}
writeln!(f)?;
for row in (0..N).rev() {
write!(f, "{}", row + 1)?;
for col in 0..N {
if self.black.get_at(row, col) {
write!(f, " X")?;
Expand All @@ -195,8 +234,14 @@ impl<const N: usize> fmt::Display for State<N> {
write!(f, " .")?;
}
}
write!(f, " {}", row + 1)?;
writeln!(f)?;
}
write!(f, " ")?;
for c in FILES.iter().take(N) {
write!(f, " {}", *c as char)?;
}
writeln!(f)?;
Ok(())
}
}
Expand Down Expand Up @@ -250,7 +295,7 @@ mod tests {

#[test]
fn test_gonnect() {
random_play::<Gonnect<3>>();
random_play::<Gonnect<6>>();
}

#[test]
Expand All @@ -259,7 +304,7 @@ mod tests {
SearchConfig::default()
.expand_threshold(1)
.q_init(crate::strategies::mcts::node::UnvisitedValueEstimate::Draw)
.max_iterations(500),
.max_iterations(20),
);
_ = search.choose_action(&State::default());
render::render(&search);
Expand Down
59 changes: 59 additions & 0 deletions src/strategies/human.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use std::io;
use std::io::Write;
use std::marker::PhantomData;

use crate::{game::Game, strategies::Search};

pub struct HumanAgent<G: Game> {
name: String,
marker: PhantomData<G>,
}

impl<G: Game> Default for HumanAgent<G> {
fn default() -> Self {
Self::new()
}
}

impl<G: Game> HumanAgent<G> {
pub fn new() -> Self {
Self {
name: "human".into(),
marker: PhantomData,
}
}
}

impl<G: Game> Search for HumanAgent<G>
where
G::S: std::fmt::Display,
{
type G = G;

fn choose_action(&mut self, state: &<Self::G as Game>::S) -> <Self::G as Game>::A {
print!("State is:\n{}", state);
let mut input = String::new();
loop {
input.clear();
print!("> ");
io::stdout().flush().expect("Failed to flush stdout");
match io::stdin().read_line(&mut input) {
Ok(_) => match G::parse_action(state, input.as_str()) {
None => eprintln!("Error parsing input: >{}<", input),
Some(action) => return action,
},
Err(error) => {
eprintln!("Error reading input: {}", error);
}
}
}
}

fn friendly_name(&self) -> String {
self.name.clone()
}

fn set_friendly_name(&mut self, name: &str) {
self.name = name.to_string();
}
}
1 change: 1 addition & 0 deletions src/strategies/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod flat_mc;
pub mod human;
pub mod mcts;
pub mod random;

Expand Down

0 comments on commit e071034

Please sign in to comment.