Skip to content

Commit

Permalink
Merge pull request #2 from thomasmarsh/qbf
Browse files Browse the repository at this point in the history
Refactor QBF into MCTS
  • Loading branch information
thomasmarsh authored Feb 20, 2024
2 parents 8e4913b + bc2226c commit f3a48f3
Show file tree
Hide file tree
Showing 9 changed files with 523 additions and 393 deletions.
39 changes: 22 additions & 17 deletions demo/book.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,30 @@ use std::sync::Mutex;
use std::time::Duration;

use mcts::game::Game;
use mcts::strategies::mcts::meta::OpeningBook;
use mcts::strategies::mcts::meta::QuasiBestFirst;
use mcts::strategies::mcts::book::OpeningBook;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;
use mcts::strategies::mcts::util;
use mcts::strategies::mcts::SearchConfig;
use mcts::strategies::mcts::TreeSearch;
use mcts::strategies::Search;

use mcts::games::druid::Druid;
use mcts::games::druid::Move;
use mcts::games::druid::State;

use rand::rngs::SmallRng;
use rand_core::SeedableRng;

// QBF Config
const NUM_THREADS: usize = 8;
const NUM_GAMES: usize = 18000;
const EPSILON: f64 = 0.5;

// MCTS Config
const PLAYOUT_DEPTH: usize = 200;
const C_TUNED: f64 = 1.625;
const MAX_ITER: usize = usize::MAX;
const MAX_ITER: usize = 1000; // usize::MAX;
const EXPAND_THRESHOLD: u32 = 1;
const VERBOSE: bool = false;
const MAX_TIME_SECS: u64 = 5; // 0 = infinite
const MAX_TIME_SECS: u64 = 0; // infinite

pub fn debug(book: &OpeningBook<Move>) {
println!("book.len() = {}", book.index.len());
Expand Down Expand Up @@ -63,6 +61,19 @@ fn make_mcts() -> TreeSearch<Druid, util::Ucb1Mast> {
.verbose(VERBOSE)
}

/// Builds the tree search used to produce opening-book entries.
///
/// The search's selection strategy is a `select::EpsilonGreedy` (with
/// `epsilon` set to `EPSILON`) whose `inner` strategy is a
/// `select::QuasiBestFirst` holding `book` and a search from `make_mcts()`.
/// Every other field and config option is left at its `Default` value.
fn make_qbf(book: OpeningBook<Move>) -> TreeSearch<Druid, util::QuasiBestFirst> {
    // Bindings are created in the order the equivalent nested expression
    // would evaluate them: outer search, config, then the strategy from the
    // inside out.
    let search = TreeSearch::default();
    let config = SearchConfig::default();

    let quasi_best_first = select::QuasiBestFirst {
        book,
        search: make_mcts(),
        ..Default::default()
    };
    let selector = select::EpsilonGreedy {
        epsilon: EPSILON,
        inner: quasi_best_first,
        ..Default::default()
    };

    search.config(config.select(selector))
}

fn main() {
color_backtrace::install();

Expand All @@ -73,18 +84,12 @@ fn main() {
for _ in 0..NUM_THREADS {
scope.spawn(|| {
let book = Arc::clone(&book);
for _ in 0..(NUM_GAMES / NUM_THREADS) {
let search = make_mcts();
let mut qbf: QuasiBestFirst<Druid, util::Ucb1Mast> = QuasiBestFirst::new(
book.lock().unwrap().clone(),
search,
SmallRng::from_entropy(),
);

let (stack, utilities) = qbf.search(&State::new());

for _ in 0..(NUM_GAMES / NUM_THREADS) {
let mut ts = make_qbf(book.lock().unwrap().clone());
let (key, utilities) = ts.make_book_entry(&State::new());
let mut book_mut = book.lock().unwrap();
book_mut.add(stack.as_slice(), utilities.as_slice());
book_mut.add(key.as_slice(), utilities.as_slice());
debug(&book_mut);
}
});
Expand Down
5 changes: 4 additions & 1 deletion src/strategies/mcts/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ use super::*;
use crate::game::{Game, PlayerIndex};

pub trait BackpropStrategy: Clone + Sync + Send {
// TODO: clean up the arguments to this, or just move it to TreeSearch
#[allow(clippy::too_many_arguments)]
fn update<G>(
&self,
ctx: &mut SearchContext<G>,
mut stack: Vec<Id>,
global: &mut TreeStats<G>,
index: &mut TreeIndex<G::A>,
trial: simulate::Trial<G>,
Expand All @@ -22,7 +25,7 @@ pub trait BackpropStrategy: Clone + Sync + Send {
vec![]
};

while let Some(node_id) = ctx.stack.pop() {
while let Some(node_id) = stack.pop() {
let node = index.get(node_id);
let next_action = if !node.is_root() {
Some(node.action(index))
Expand Down
121 changes: 121 additions & 0 deletions src/strategies/mcts/book.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use super::index;
use crate::game::Action;

use rustc_hash::FxHashMap;
use serde::Serialize;

/// One node of the opening book: the statistics recorded for a single
/// action sequence, plus links to the entries that extend it by one action.
#[derive(Clone, Debug, Serialize)]
pub struct Entry<A: Action> {
/// Maps an action to the arena id of the child entry reached by it.
pub children: FxHashMap<A, index::Id>,
/// Running utility totals, one slot per player; `update` adds each
/// recorded result to these element-wise.
pub utilities: Vec<f64>,
/// Number of results that `update` has accumulated into this entry.
pub num_visits: u64,
}

impl<A: Action> Entry<A> {
    /// Folds one result into this entry.
    ///
    /// Each element of `utilities` is added to the running total at the same
    /// position, and the visit counter is incremented by one.
    ///
    /// # Panics
    ///
    /// Panics if `utilities` does not have the same length as the totals
    /// stored in this entry.
    fn update(&mut self, utilities: &[f64]) {
        assert_eq!(self.utilities.len(), utilities.len());

        for (total, delta) in self.utilities.iter_mut().zip(utilities) {
            *total += *delta;
        }

        self.num_visits += 1;
    }

    /// Returns the rescaled mean utility for `player`, or `None` if this
    /// entry has never been updated.
    ///
    /// # Panics
    ///
    /// Panics if `player` is not a valid index into the stored totals and
    /// the entry has at least one visit.
    fn score(&self, player: usize) -> Option<f64> {
        if self.num_visits == 0 {
            return None;
        }

        // The mean is assumed to lie in -1..1 (per the original author's
        // note); the affine map below would then place it in 0..1.
        let mean = self.utilities[player] / self.num_visits as f64;
        Some((mean + 1.) / 2.)
    }

    /// Creates an unvisited entry with no children and a zeroed total for
    /// each of `num_players` players.
    fn new(num_players: usize) -> Self {
        let utilities = vec![0.; num_players];
        Self {
            children: Default::default(),
            utilities,
            num_visits: 0,
        }
    }
}

/// A tree of action sequences in which every node is an [`Entry`] stored in
/// an arena and addressed by id.
#[derive(Clone, Debug)]
pub struct OpeningBook<A: Action> {
/// Arena that owns every entry in the book.
pub index: index::Arena<Entry<A>>,
/// Id of the root entry, i.e. the entry for the empty action sequence.
pub root_id: index::Id,
/// Number of utility slots given to each entry the book creates.
pub num_players: usize,
}

impl<A: Action> OpeningBook<A> {
    /// Creates a book that contains only an unvisited root entry sized for
    /// `num_players` players.
    pub fn new(num_players: usize) -> Self {
        let mut arena = index::Arena::new();
        let root_id = arena.insert(Entry::new(num_players));

        Self {
            index: arena,
            root_id,
            num_players,
        }
    }

    /// Mutably borrows the entry stored under `id`.
    fn get_mut(&mut self, id: index::Id) -> &mut Entry<A> {
        let arena = &mut self.index;
        arena.get_mut(id)
    }

    /// Borrows the entry stored under `id`.
    fn get(&self, id: index::Id) -> &Entry<A> {
        let arena = &self.index;
        arena.get(id)
    }

    /// Stores `value` in the arena and returns the id assigned to it.
    fn insert(&mut self, value: Entry<A>) -> index::Id {
        let arena = &mut self.index;
        arena.insert(value)
    }
}

impl<A: Action> OpeningBook<A> {
    /// Returns the id of the child of `id` reached by `action`, creating an
    /// unvisited child entry first if none exists yet.
    fn get_child(&mut self, id: index::Id, action: &A) -> index::Id {
        // Hit path: a single map lookup, and the id is copied out so the
        // borrow of `self` ends before any mutation below.
        if let Some(&child_id) = self.get(id).children.get(action) {
            return child_id;
        }

        // Miss path: allocate the entry in the arena, then link it from the
        // parent. The freshly assigned id is returned directly rather than
        // being looked up again.
        let child_id = self.insert(Entry::new(self.num_players));
        self.get_mut(id).children.insert(action.clone(), child_id);
        child_id
    }

    /// Records one result along `sequence`.
    ///
    /// `utilities` is accumulated into the root entry and into the entry for
    /// every prefix of `sequence`, creating entries that do not exist yet.
    ///
    /// # Panics
    ///
    /// Panics if `utilities` has a different length than the totals stored
    /// in an entry it is accumulated into.
    pub fn add(&mut self, sequence: &[A], utilities: &[f64]) {
        let mut current_id = self.root_id;
        self.get_mut(current_id).update(utilities);

        for action in sequence {
            current_id = self.get_child(current_id, action);
            self.get_mut(current_id).update(utilities);
        }
    }

    /// Looks up the score of `player` for the entry reached by following
    /// `sequence` from the root.
    ///
    /// Returns `None` if some action in `sequence` has no entry in the book,
    /// or if the entry that is reached has never been updated.
    pub fn score(&self, sequence: &[A], player: usize) -> Option<f64> {
        let mut current_id = self.root_id;
        for action in sequence {
            // `?` bails out with `None` as soon as the path leaves the book.
            current_id = *self.get(current_id).children.get(action)?;
        }
        self.get(current_id).score(player)
    }
}
Loading

0 comments on commit f3a48f3

Please sign in to comment.