From 26afb8ee1d60de3c1f1f93c19788cbdc99d2164c Mon Sep 17 00:00:00 2001 From: Thomas Marsh Date: Mon, 19 Feb 2024 22:33:19 -0500 Subject: [PATCH 1/2] Refactor QBF into MCTS --- demo/book.rs | 37 ++-- src/strategies/mcts/backprop.rs | 5 +- src/strategies/mcts/book.rs | 121 +++++++++++++ src/strategies/mcts/meta.rs | 290 ------------------------------- src/strategies/mcts/mod.rs | 92 +++++++--- src/strategies/mcts/select.rs | 298 +++++++++++++++++++++++++++----- src/strategies/mcts/simulate.rs | 6 +- src/strategies/mcts/util.rs | 35 +++- src/strategies/mod.rs | 8 + 9 files changed, 504 insertions(+), 388 deletions(-) create mode 100644 src/strategies/mcts/book.rs delete mode 100644 src/strategies/mcts/meta.rs diff --git a/demo/book.rs b/demo/book.rs index 41e840e..4709804 100644 --- a/demo/book.rs +++ b/demo/book.rs @@ -3,21 +3,18 @@ use std::sync::Mutex; use std::time::Duration; use mcts::game::Game; -use mcts::strategies::mcts::meta::OpeningBook; -use mcts::strategies::mcts::meta::QuasiBestFirst; +use mcts::strategies::mcts::book::OpeningBook; use mcts::strategies::mcts::select; use mcts::strategies::mcts::simulate; use mcts::strategies::mcts::util; use mcts::strategies::mcts::SearchConfig; use mcts::strategies::mcts::TreeSearch; +use mcts::strategies::Search; use mcts::games::druid::Druid; use mcts::games::druid::Move; use mcts::games::druid::State; -use rand::rngs::SmallRng; -use rand_core::SeedableRng; - // QBF Config const NUM_THREADS: usize = 8; const NUM_GAMES: usize = 18000; @@ -25,10 +22,10 @@ const NUM_GAMES: usize = 18000; // MCTS Config const PLAYOUT_DEPTH: usize = 200; const C_TUNED: f64 = 1.625; -const MAX_ITER: usize = usize::MAX; +const MAX_ITER: usize = 1000; // usize::MAX; const EXPAND_THRESHOLD: u32 = 1; const VERBOSE: bool = false; -const MAX_TIME_SECS: u64 = 5; // 0 = infinite +const MAX_TIME_SECS: u64 = 0; // infinite pub fn debug(book: &OpeningBook) { println!("book.len() = {}", book.index.len()); @@ -63,6 +60,18 @@ fn make_mcts() -> TreeSearch { .verbose(VERBOSE) } +fn make_qbf(book: OpeningBook) -> TreeSearch { + // This is a little crazy. + TreeSearch::default().config(SearchConfig::default().select(select::EpsilonGreedy { + inner: select::QuasiBestFirst { + book, + search: make_mcts(), + ..Default::default() + }, + ..Default::default() + })) +} + fn main() { color_backtrace::install(); @@ -73,18 +82,12 @@ fn main() { for _ in 0..NUM_THREADS { scope.spawn(|| { let book = Arc::clone(&book); - for _ in 0..(NUM_GAMES / NUM_THREADS) { - let search = make_mcts(); - let mut qbf: QuasiBestFirst = QuasiBestFirst::new( - book.lock().unwrap().clone(), - search, - SmallRng::from_entropy(), - ); - - let (stack, utilities) = qbf.search(&State::new()); + for _ in 0..(NUM_GAMES / NUM_THREADS) { + let mut ts = make_qbf(book.lock().unwrap().clone()); + let (key, utilities) = ts.make_book_entry(&State::new()); let mut book_mut = book.lock().unwrap(); - book_mut.add(stack.as_slice(), utilities.as_slice()); + book_mut.add(key.as_slice(), utilities.as_slice()); debug(&book_mut); } }); diff --git a/src/strategies/mcts/backprop.rs b/src/strategies/mcts/backprop.rs index e8edf74..2ef9a5a 100644 --- a/src/strategies/mcts/backprop.rs +++ b/src/strategies/mcts/backprop.rs @@ -3,9 +3,12 @@ use super::*; use crate::game::{Game, PlayerIndex}; pub trait BackpropStrategy: Clone + Sync + Send { + // TODO: cleanup the arguments to this, or just move it to TreeSearch + #[allow(clippy::too_many_arguments)] fn update( &self, ctx: &mut SearchContext, + mut stack: Vec, global: &mut TreeStats, index: &mut TreeIndex, trial: simulate::Trial, @@ -22,7 +25,7 @@ pub trait BackpropStrategy: Clone + Sync + Send { vec![] }; - while let Some(node_id) = ctx.stack.pop() { + while let Some(node_id) = stack.pop() { let node = index.get(node_id); let next_action = if !node.is_root() { Some(node.action(index)) diff --git a/src/strategies/mcts/book.rs b/src/strategies/mcts/book.rs new file mode 100644 index 0000000..12c59f0 --- /dev/null +++ b/src/strategies/mcts/book.rs @@ -0,0 +1,121 @@ +use super::index; +use crate::game::Action; + +use rustc_hash::FxHashMap; +use serde::Serialize; + +#[derive(Clone, Debug, Serialize)] +pub struct Entry { + pub children: FxHashMap, + pub utilities: Vec, + pub num_visits: u64, +} + +impl Entry { + fn update(&mut self, utilities: &[f64]) { + assert_eq!(self.utilities.len(), utilities.len()); + self.utilities + .iter_mut() + .enumerate() + .for_each(|(i, score)| { + *score += utilities[i]; + }); + + self.num_visits += 1; + } + + fn score(&self, player: usize) -> Option { + if self.num_visits == 0 { + None + } else { + let q = self.utilities[player]; + let n = self.num_visits as f64; + let avg_q = q / n; // -1..1 + Some((avg_q + 1.) / 2.) + } + } + + fn new(num_players: usize) -> Self { + Self { + children: Default::default(), + utilities: vec![0.; num_players], + num_visits: 0, + } + } +} + +#[derive(Clone, Debug)] +pub struct OpeningBook { + pub index: index::Arena>, + pub root_id: index::Id, + pub num_players: usize, +} + +impl OpeningBook { + pub fn new(num_players: usize) -> Self { + let mut index = index::Arena::new(); + let root_id = index.insert(Entry::new(num_players)); + Self { + index, + root_id, + num_players, + } + } + + fn get_mut(&mut self, id: index::Id) -> &mut Entry { + self.index.get_mut(id) + } + + fn get(&self, id: index::Id) -> &Entry { + self.index.get(id) + } + + fn insert(&mut self, value: Entry) -> index::Id { + self.index.insert(value) + } +} + +impl OpeningBook { + fn contains_action(&self, id: index::Id, action: &A) -> bool { + self.index.get(id).children.contains_key(action) + } + + // Get or insert a child for this id + fn get_child(&mut self, id: index::Id, action: &A) -> index::Id { + if !self.contains_action(id, action) { + // Insert into index + let child_id = self.insert(Entry::new(self.num_players)); + + // Place index reference in hash map + self.index + .get_mut(id) + .children + .insert(action.clone(), child_id); + } + + // Return the child id + *self.index.get(id).children.get(action).unwrap() + } + + pub fn add(&mut self, sequence: &[A], utilities: &[f64]) { + let mut current_id = self.root_id; + self.get_mut(current_id).update(utilities); + + sequence.iter().for_each(|action| { + current_id = self.get_child(current_id, action); + self.get_mut(current_id).update(utilities); + }); + } + + pub fn score(&self, sequence: &[A], player: usize) -> Option { + let mut current_id = self.root_id; + for action in sequence { + if let Some(child_id) = self.get(current_id).children.get(action) { + current_id = *child_id; + } else { + return None; + } + } + self.get(current_id).score(player) + } +} diff --git a/src/strategies/mcts/meta.rs b/src/strategies/mcts/meta.rs deleted file mode 100644 index c068e51..0000000 --- a/src/strategies/mcts/meta.rs +++ /dev/null @@ -1,290 +0,0 @@ -use super::index; -use super::SearchConfig; -use super::Strategy; -use super::TreeSearch; -use crate::game::Action; -use crate::game::Game; -use crate::game::PlayerIndex; -use crate::strategies::Search; -use crate::util::random_best; - -use rand::rngs::SmallRng; -use rand::Rng; -use rustc_hash::FxHashMap; -use serde::Serialize; - -// This is not mentioned in the Chaslot paper, but QBF seems too greedy -// without epsilon-greedy. -const EPSILON: f64 = 0.3; - -#[derive(Clone, Debug, Serialize)] -pub struct Entry { - pub children: FxHashMap, - pub utilities: Vec, - pub num_visits: u64, -} - -impl Entry { - fn update(&mut self, utilities: &[f64]) { - assert_eq!(self.utilities.len(), utilities.len()); - self.utilities - .iter_mut() - .enumerate() - .for_each(|(i, score)| { - *score += utilities[i]; - }); - - self.num_visits += 1; - } - - fn score(&self, player: usize) -> Option { - if self.num_visits == 0 { - None - } else { - let q = self.utilities[player]; - let n = self.num_visits as f64; - let avg_q = q / n; // -1..1 - Some((avg_q + 1.) / 2.) - } - } - - fn new(num_players: usize) -> Self { - Self { - children: Default::default(), - utilities: vec![0.; num_players], - num_visits: 0, - } - } -} - -#[derive(Clone, Debug)] -pub struct OpeningBook { - pub index: index::Arena>, - pub root_id: index::Id, - pub num_players: usize, -} - -impl OpeningBook { - pub fn new(num_players: usize) -> Self { - let mut index = index::Arena::new(); - let root_id = index.insert(Entry::new(num_players)); - Self { - index, - root_id, - num_players, - } - } - - fn get_mut(&mut self, id: index::Id) -> &mut Entry { - self.index.get_mut(id) - } - - fn get(&self, id: index::Id) -> &Entry { - self.index.get(id) - } - - fn insert(&mut self, value: Entry) -> index::Id { - self.index.insert(value) - } -} - -impl OpeningBook { - fn contains_action(&self, id: index::Id, action: &A) -> bool { - self.index.get(id).children.contains_key(action) - } - - // Get or insert a child for this id - fn get_child(&mut self, id: index::Id, action: &A) -> index::Id { - if !self.contains_action(id, action) { - // Insert into index - let child_id = self.insert(Entry::new(self.num_players)); - - // Place index reference in hash map - self.index - .get_mut(id) - .children - .insert(action.clone(), child_id); - } - - // Return the child id - *self.index.get(id).children.get(action).unwrap() - } - - pub fn add(&mut self, sequence: &[A], utilities: &[f64]) { - let mut current_id = self.root_id; - self.get_mut(current_id).update(utilities); - - sequence.iter().for_each(|action| { - current_id = self.get_child(current_id, action); - self.get_mut(current_id).update(utilities); - }); - } - - pub fn score(&self, sequence: &[A], player: usize) -> Option { - let mut current_id = self.root_id; - for action in sequence { - if let Some(child_id) = self.get(current_id).children.get(action) { - current_id = *child_id; - } else { - return None; - } - } - self.get(current_id).score(player) - } -} - -#[derive(Clone)] -pub struct QuasiBestFirst> { - pub k: Vec, - pub book: OpeningBook, - pub search: TreeSearch, - pub rng: SmallRng, -} - -/// NOTE: this algorithm seems like it could be implemented with the following -/// settings on TreeSearch: -/// -/// - max_iter: 1 -/// - expand_threshold: 0 -/// - select: qbf -/// - backprop: n/a -/// - simulate: n/a -/// -/// Algorithm 1 The “Quasi Best-First” (QBF) algorithm. λ is the number of machines -/// available. K is a constant. g is a game, defined as a sequence of game states. -/// The function “MoGoChoice” asks MOGO to choose a move. -/// -/// ```ignore -/// QBF(K, λ) -/// while True do -/// for l = 1..λ, do -/// s =initial state; g = {s}. -/// while s is not a final state do -/// bestScore = K -/// bestMove = Null -/// for m in the set of possible moves in s do -/// score = percentage of won games by playing the move m in s -/// if score > bestScore then -/// bestScore = score -/// bestMove = m -/// end if -/// end for -/// if bestMove = Null then -/// bestMove = MoGoChoice(s) // lower level MCTS -/// end if -/// s = playMove(s, bestMove) -/// g = concat(g, s) -/// end while -/// Add g and the result of the game in the book. -/// end for -/// end while -/// ``` -impl QuasiBestFirst -where - G: Game, - S: Strategy, - SearchConfig: Default, -{ - pub fn new(book: OpeningBook, search: TreeSearch, rng: SmallRng) -> Self { - // The default value here is 0.5, but the Chaslot paper noted the difficulty - // of elevating the black player in go when cold starting, prompting a lower - // threshold for the initial player. - // TODO: what about N-player games where N > 2 - let mut k = vec![0.5; G::num_players()]; - if k.len() == 2 { - k[0] = 0.1; - } - Self { - k, - book, - search, - rng, - } - } - - /// Search is expected to be called multiple times to fill out the book. - pub fn search(&mut self, init: &G::S) -> (Vec, Vec) { - let mut stack = Vec::new(); - let mut state = init.clone(); - while !G::is_terminal(&state) { - let mut actions = Vec::new(); - G::generate_actions(&state, &mut actions); - let player = G::player_to_move(&state).to_index(); - let index = self.best_child(player, stack.as_slice(), &state); - state = G::apply(state, &actions[index]); - stack.push(actions[index].clone()); - } - - let utilities = G::compute_utilities(&state); - - (stack, utilities) - } - - pub fn debug(&self, init: &G::S) { - println!("book.len() = {}", self.book.index.len()); - let mut actions = Vec::new(); - G::generate_actions(init, &mut actions); - - self.search.index.get(self.search.root_id).actions(); - - let root = self.book.index.get(self.book.root_id); - println!("root: {}", root.num_visits); - actions.iter().enumerate().for_each(|(i, action)| { - let child_id_opt = root.children.get(action); - let child = child_id_opt.map(|child_id| self.book.index.get(*child_id)); - let score = self.book.score(&[action.clone()], 0); - println!( - "- {i}: {:?}, {score:?} {action:?}", - child.map_or(0, |c| c.num_visits), - ); - }); - } - - fn best_child(&mut self, player: usize, stack: &[G::A], state: &G::S) -> usize { - let k_score = self.k[player]; - - // The child actions, enumerated since we plan to return an index. - let mut available = Vec::new(); - G::generate_actions(state, &mut available); - - if self.rng.gen::() < EPSILON { - return self.rng.gen_range(0..available.len()); - } - - // The prefix list of actions we use as a key - let key_init = stack.to_vec(); - - // TODO: a lot of the difficulty here is the handling of optionals. It would make - // sense to have most of the SelectStrategy API return optionals, but it hasn't - // been necessary until this point. Additionally, random_best and random_best_index - // aren't great fits. We are misusing random_best here a bit w.r.t. neg infinity. - let enumerated = available.iter().cloned().enumerate().collect::>(); - let best = random_best( - enumerated.as_slice(), - &mut self.rng, - |(_, action): &(usize, G::A)| { - let mut key = key_init.clone(); - key.push(action.clone()); - - let score = self - .book - .score(key.as_slice(), player) - .unwrap_or(f64::NEG_INFINITY); - if score > k_score { - score - } else { - // NOTE: we depend on random_best using this value internally - // as an equivalence for None types - f64::NEG_INFINITY - } - }, - ); - - if let Some((best_index, _)) = best { - *best_index - } else { - let action = self.search.choose_action(state); - available.iter().position(|p| *p == action.clone()).unwrap() - } - } -} diff --git a/src/strategies/mcts/mod.rs b/src/strategies/mcts/mod.rs index 269f8ae..fe3969b 100644 --- a/src/strategies/mcts/mod.rs +++ b/src/strategies/mcts/mod.rs @@ -1,6 +1,6 @@ pub mod backprop; +pub mod book; pub mod index; -pub mod meta; pub mod node; pub mod select; pub mod simulate; @@ -57,11 +57,11 @@ impl std::ops::BitOr for BackpropFlags { //////////////////////////////////////////////////////////////////////////////// -pub trait Strategy: Clone { - type Select: select::SelectStrategy; +pub trait Strategy: Clone + Sync + Send { + type Select: select::SelectStrategy; type Simulate: simulate::SimulateStrategy; type Backprop: backprop::BackpropStrategy; - type FinalAction: select::SelectStrategy; + type FinalAction: select::SelectStrategy; fn friendly_name() -> String; } @@ -71,7 +71,6 @@ pub struct SearchConfig where G: Game, S: Strategy, - SearchConfig: Sync + Send, { pub select: S::Select, pub simulate: S::Simulate, @@ -88,7 +87,6 @@ impl SearchConfig where G: Game, S: Strategy, - SearchConfig: Sync + Send, { pub fn select(mut self, select: S::Select) -> Self { self.select = select; @@ -144,16 +142,11 @@ where pub struct SearchContext { pub current_id: Id, pub state: G::S, - pub stack: Vec, } impl SearchContext { pub fn new(current_id: Id, state: G::S) -> Self { - Self { - current_id, - state, - stack: vec![], - } + Self { current_id, state } } #[inline] @@ -204,6 +197,8 @@ where pub config: SearchConfig, pub stats: TreeStats, + pub stack: Vec, + pub trial: Option>, pub rng: SmallRng, pub verbose: bool, pub name: String, @@ -213,7 +208,6 @@ impl TreeSearch where G: Game, S: Strategy, - SearchConfig: Sync + Send, { pub fn rng(mut self, rng: SmallRng) -> Self { self.rng = rng; @@ -240,7 +234,7 @@ impl Default for TreeSearch where G: Game, S: Strategy, - SearchConfig: Default + Sync + Send, + SearchConfig: Default, { fn default() -> Self { Self::new(SearchConfig::default(), SmallRng::from_entropy()) @@ -251,7 +245,7 @@ impl TreeSearch where G: Game, S: Strategy, - SearchConfig: Default + Sync + Send, + SearchConfig: Default, { pub fn new(config: SearchConfig, rng: SmallRng) -> Self { let mut index = index::Arena::new(); @@ -260,6 +254,8 @@ where root_id, init_state: None, pv: vec![], + stack: vec![], + trial: None, index, config, rng, @@ -299,7 +295,7 @@ where pub fn select(&mut self, ctx: &mut SearchContext) { let player = G::player_to_move(&ctx.state); loop { - ctx.stack.push(ctx.current_id); + self.stack.push(ctx.current_id); let node = self.index.get(ctx.current_id); if node.is_terminal() || node.stats.num_visits < self.config.expand_threshold { @@ -318,7 +314,10 @@ where let select_ctx = SelectContext { q_init: self.config.q_init, current_id: ctx.current_id, + stack: self.stack.clone(), player: player.to_index(), + player_to_move: G::player_to_move(&ctx.state).to_index(), + state: &ctx.state, index: &self.index, }; self.config.select.best_child(&select_ctx, &mut self.rng) @@ -352,7 +351,7 @@ where ctx.traverse(child_id); ctx.state = state; - ctx.stack.push(ctx.current_id); + self.stack.push(ctx.current_id); if self.config.expand_threshold > 0 { return; @@ -367,7 +366,10 @@ where &SelectContext { q_init: self.config.q_init, current_id: self.root_id, + stack: self.stack.clone(), player: G::player_to_move(state).to_index(), + player_to_move: G::player_to_move(state).to_index(), + state, index: &self.index, }, &mut self.rng, @@ -391,15 +393,23 @@ where } #[inline] - pub(crate) fn backprop(&mut self, ctx: &mut SearchContext, trial: Trial, player: usize) { + pub(crate) fn backprop(&mut self, ctx: &mut SearchContext, player: usize) { self.stats.iter_count += 1; - self.stats.accum_depth += trial.depth + ctx.stack.len() - 1; + self.stats.accum_depth += self.trial.as_ref().unwrap().depth + self.stack.len() - 1; let flags = self.config.select.backprop_flags() | self.config.simulate.backprop_flags(); self.config .backprop // TODO: may as well pass &mut self? Seems like the separation // of concerns is not ideal. - .update(ctx, &mut self.stats, &mut self.index, trial, player, flags); + .update( + ctx, + self.stack.clone(), + &mut self.stats, + &mut self.index, + self.trial.as_ref().unwrap().clone(), + player, + flags, + ); } #[allow(dead_code)] @@ -475,6 +485,7 @@ where fn reset(&mut self) -> Id { self.index.clear(); + self.stack.clear(); self.stats.accum_depth = 0; self.stats.iter_count = 0; self.new_root() @@ -491,6 +502,9 @@ where q_init: self.config.q_init, current_id: node_id, player: player.to_index(), + stack: self.stack.clone(), + state: &state, + player_to_move: player.to_index(), index: &self.index, }; let best_idx = self @@ -514,11 +528,7 @@ impl super::Search for TreeSearch where G: Game, S: Strategy, - SearchConfig: Default + Sync + Send, - >::Select: Sync + Send, - >::FinalAction: Sync + Send, - >::Backprop: Sync + Send, - >::Simulate: Sync + Send, + SearchConfig: Default, { type G = G; @@ -538,8 +548,8 @@ where } let mut ctx = SearchContext::new(root_id, state.clone()); self.select(&mut ctx); - let trial = self.simulate(&ctx.state, G::player_to_move(state).to_index()); - self.backprop(&mut ctx, trial, G::player_to_move(state).to_index()); + self.trial = Some(self.simulate(&ctx.state, G::player_to_move(state).to_index())); + self.backprop(&mut ctx, G::player_to_move(state).to_index()); } self.compute_pv(); @@ -549,10 +559,36 @@ where // // max_iterations < expand_threshold // - // TODO: We might check for this and unconditionally expand root. + // TODO: We might check for this and unconditionally expand root. I think + // a lot of implementations fully expand root on the first iteration. self.select_final_action(state) } + fn make_book_entry( + &mut self, + state: &::S, + ) -> (Vec<::A>, Vec) { + assert_eq!(self.config.expand_threshold, 0); + assert_eq!(self.config.max_iterations, 1); + + // Run the search, with expand_threshold == 0, so we fully expand to the + // terminal node. + _ = self.choose_action(state); + + // The stack now contains the action path to the terminal state. + let actions = self + .stack + .iter() + .skip(1) + .cloned() + .map(|id| self.index.get(id).action(&self.index)) + .collect(); + + let utilities = G::compute_utilities(&self.trial.as_ref().unwrap().state); + + (actions, utilities) + } + fn estimated_depth(&self) -> usize { (self.stats.accum_depth as f64 / self.stats.iter_count as f64).round() as usize } diff --git a/src/strategies/mcts/select.rs b/src/strategies/mcts/select.rs index 7ab1f1d..f217719 100644 --- a/src/strategies/mcts/select.rs +++ b/src/strategies/mcts/select.rs @@ -3,36 +3,40 @@ use std::sync::atomic::Ordering::Relaxed; use rand::rngs::SmallRng; use super::*; -use crate::game::Action; +use crate::game::Game; +use crate::strategies::Search; -pub struct SelectContext<'a, A: Action> { +pub struct SelectContext<'a, G: Game> { pub q_init: node::UnvisitedValueEstimate, pub current_id: index::Id, + pub stack: Vec, + pub state: &'a G::S, pub player: usize, - pub index: &'a TreeIndex, + pub player_to_move: usize, + pub index: &'a TreeIndex, } //////////////////////////////////////////////////////////////////////////////// -pub trait SelectStrategy: Sized + Clone + Sync + Send { +pub trait SelectStrategy: Sized + Clone + Sync + Send { type Score: PartialOrd + Copy; type Aux: Copy; /// If the strategy wants to lift any calculations out of the inner select /// loop, then they can provide this here. - fn setup(&mut self, ctx: &SelectContext<'_, A>) -> Self::Aux; + fn setup(&mut self, ctx: &SelectContext<'_, G>) -> Self::Aux; /// Default implementation should be sufficient for all cases. - fn best_child(&mut self, ctx: &SelectContext<'_, A>, rng: &mut SmallRng) -> usize { + fn best_child(&mut self, ctx: &SelectContext<'_, G>, rng: &mut SmallRng) -> usize { let current = ctx.index.get(ctx.current_id); random_best_index(current.children(), self, ctx, rng) } /// Given a child index, calculate a score. - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, aux: Self::Aux) -> Self::Score; + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, aux: Self::Aux) -> Self::Score; /// Provide a score for any value that is not yet visited. - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, aux: Self::Aux) -> Self::Score; + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, aux: Self::Aux) -> Self::Score; fn backprop_flags(&self) -> BackpropFlags { BackpropFlags(0) @@ -41,6 +45,64 @@ pub trait SelectStrategy: Sized + Clone + Sync + Send { //////////////////////////////////////////////////////////////////////////////// +#[derive(Clone)] +pub struct EpsilonGreedy> { + pub epsilon: f64, + pub inner: S, + pub marker: std::marker::PhantomData, +} + +impl Default for EpsilonGreedy +where + G: Game, + S: SelectStrategy + Default, +{ + fn default() -> Self { + Self { + epsilon: 0.1, + inner: S::default(), + marker: std::marker::PhantomData, + } + } +} + +impl SelectStrategy for EpsilonGreedy +where + G: Game, + S: SelectStrategy, +{ + type Score = S::Score; + type Aux = S::Aux; + + fn best_child(&mut self, ctx: &SelectContext<'_, G>, rng: &mut SmallRng) -> usize { + if rng.gen::() < self.epsilon { + let current = ctx.index.get(ctx.current_id); + let n = current.children().len(); + rng.gen_range(0..n) + } else { + self.inner.best_child(ctx, rng) + } + } + + fn setup(&mut self, ctx: &SelectContext<'_, G>) -> Self::Aux { + self.inner.setup(ctx) + } + + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, aux: Self::Aux) -> Self::Score { + self.inner.score_child(ctx, child_id, aux) + } + + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, aux: Self::Aux) -> Self::Score { + self.inner.unvisited_value(ctx, aux) + } + + fn backprop_flags(&self) -> BackpropFlags { + self.inner.backprop_flags() + } +} + +//////////////////////////////////////////////////////////////////////////////// + const PRIMES: [usize; 16] = [ 14323, 18713, 19463, 30553, 33469, 45343, 50221, 51991, 53201, 56923, 64891, 72763, 74471, 81647, 92581, 94693, @@ -48,15 +110,15 @@ const PRIMES: [usize; 16] = [ // This function is adapted from from minimax-rs. #[inline] -fn random_best_index( +fn random_best_index( set: &[Option], strategy: &mut S, - ctx: &SelectContext<'_, A>, + ctx: &SelectContext<'_, G>, rng: &mut SmallRng, ) -> usize where - S: SelectStrategy, - A: Action, + S: SelectStrategy, + G: Game, { // To make the choice more uniformly random among the best moves, start // at a random offset and stride by a random amount. The stride must be @@ -102,15 +164,15 @@ where #[derive(Default, Clone)] pub struct RobustChild; -impl SelectStrategy for RobustChild { +impl SelectStrategy for RobustChild { type Score = (i64, f64); type Aux = (); #[inline(always)] - fn setup(&mut self, _: &SelectContext<'_, A>) -> Self::Aux {} + fn setup(&mut self, _: &SelectContext<'_, G>) -> Self::Aux {} #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, _: Self::Aux) -> (i64, f64) { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, _: Self::Aux) -> (i64, f64) { let child = ctx.index.get(child_id); ( child.stats.num_visits as i64, @@ -119,7 +181,7 @@ impl SelectStrategy for RobustChild { } #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, _: Self::Aux) -> (i64, f64) { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, _: Self::Aux) -> (i64, f64) { let q = ctx .index .get(ctx.current_id) @@ -136,20 +198,20 @@ impl SelectStrategy for RobustChild { #[derive(Default, Clone)] pub struct MaxAvgScore; -impl SelectStrategy for MaxAvgScore { +impl SelectStrategy for MaxAvgScore { type Score = f64; type Aux = (); #[inline(always)] - fn setup(&mut self, _: &SelectContext<'_, A>) -> Self::Aux {} + fn setup(&mut self, _: &SelectContext<'_, G>) -> Self::Aux {} #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, _: Self::Aux) -> f64 { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, _: Self::Aux) -> f64 { ctx.index.get(child_id).stats.expected_score(ctx.player) } #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, _: Self::Aux) -> f64 { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, _: Self::Aux) -> f64 { ctx.index .get(ctx.current_id) .stats @@ -172,15 +234,15 @@ impl Default for SecureChild { } } -impl SelectStrategy for SecureChild { +impl SelectStrategy for SecureChild { type Score = f64; type Aux = (); #[inline(always)] - fn setup(&mut self, _: &SelectContext<'_, A>) -> Self::Aux {} + fn setup(&mut self, _: &SelectContext<'_, G>) -> Self::Aux {} #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, _: Self::Aux) -> f64 { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, _: Self::Aux) -> f64 { let child = ctx.index.get(child_id); let q = child.stats.expected_score(ctx.player); let n = child.stats.num_visits + child.stats.num_visits_virtual.load(Relaxed); @@ -189,7 +251,7 @@ impl SelectStrategy for SecureChild { } #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, _: Self::Aux) -> f64 { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, _: Self::Aux) -> f64 { ctx.index .get(ctx.current_id) .stats @@ -213,18 +275,18 @@ impl Default for Ucb1 { } } -impl SelectStrategy for Ucb1 { +impl SelectStrategy for Ucb1 { type Score = f64; type Aux = f64; #[inline(always)] - fn setup(&mut self, ctx: &SelectContext<'_, A>) -> f64 { + fn setup(&mut self, ctx: &SelectContext<'_, G>) -> f64 { let current = ctx.index.get(ctx.current_id); ((current.stats.num_visits as f64).max(1.)).ln() } #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, parent_log: f64) -> f64 { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, parent_log: f64) -> f64 { let child = ctx.index.get(child_id); let exploit = child.stats.exploitation_score(ctx.player); let num_visits = child.stats.num_visits + child.stats.num_visits_virtual.load(Relaxed); @@ -233,7 +295,7 @@ impl SelectStrategy for Ucb1 { } #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, parent_log: f64) -> f64 { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, parent_log: f64) -> f64 { let current = ctx.index.get(ctx.current_id); let unvisited_value = current .stats @@ -272,18 +334,18 @@ fn ucb1_tuned( + exploration_constant * visits_fraction.sqrt()) } -impl SelectStrategy for Ucb1Tuned { +impl SelectStrategy for Ucb1Tuned { type Score = f64; type Aux = f64; #[inline(always)] - fn setup(&mut self, ctx: &SelectContext<'_, A>) -> f64 { + fn setup(&mut self, ctx: &SelectContext<'_, G>) -> f64 { let current = ctx.index.get(ctx.current_id); ((current.stats.num_visits as f64).max(1.)).ln() } #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, parent_log: f64) -> f64 { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, parent_log: f64) -> f64 { let child = ctx.index.get(child_id); let exploit = child.stats.exploitation_score(ctx.player); let num_visits = child.stats.num_visits + child.stats.num_visits_virtual.load(Relaxed); @@ -301,7 +363,7 @@ impl SelectStrategy for Ucb1Tuned { } #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, parent_log: f64) -> Self::Score { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, parent_log: f64) -> Self::Score { let current = ctx.index.get(ctx.current_id); let unvisited_value = current .stats @@ -341,12 +403,12 @@ fn grave_value(beta: f64, mean_score: f64, mean_amaf: f64) -> f64 { (1. - beta) * mean_score + beta * mean_amaf } -impl SelectStrategy for McGrave { +impl SelectStrategy for McGrave { type Score = f64; type Aux = (); #[inline(always)] - fn setup(&mut self, ctx: &SelectContext<'_, A>) -> Self::Aux { + fn setup(&mut self, ctx: &SelectContext<'_, G>) -> Self::Aux { let current = ctx.index.get(ctx.current_id); if self.current_ref_id.is_none() @@ -358,7 +420,7 @@ impl SelectStrategy for McGrave { } #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, _: Self::Aux) -> f64 { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, _: Self::Aux) -> f64 { let t = ctx.index.get(child_id); let tref = ctx.index.get(self.current_ref_id.unwrap()); let p = (t.stats.num_visits + t.stats.num_visits_virtual.load(Relaxed)) as f64; @@ -378,7 +440,7 @@ impl SelectStrategy for McGrave { } #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, _: Self::Aux) -> f64 { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, _: Self::Aux) -> f64 { ctx.index .get(ctx.current_id) .stats @@ -403,15 +465,15 @@ impl Default for McBrave { } } -impl SelectStrategy for McBrave { +impl SelectStrategy for McBrave { type Score = f64; type Aux = (); #[inline(always)] - fn setup(&mut self, _: &SelectContext<'_, A>) -> Self::Aux {} + fn setup(&mut self, _: &SelectContext<'_, G>) -> Self::Aux {} #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, _: Self::Aux) -> Self::Score { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, _: Self::Aux) -> Self::Score { let current = ctx.index.get(ctx.current_id); let unvisited_value = current .stats @@ -420,7 +482,7 @@ impl SelectStrategy for McBrave { } #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, _: Self::Aux) -> Self::Score { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, _: Self::Aux) -> Self::Score { let child = ctx.index.get(child_id); let mean_score = child.stats.exploitation_score(ctx.player); @@ -484,18 +546,18 @@ impl Default for ScalarAmaf { } } -impl SelectStrategy for ScalarAmaf { +impl SelectStrategy for ScalarAmaf { type Score = f64; type Aux = f64; #[inline(always)] - fn setup(&mut self, ctx: &SelectContext<'_, A>) -> f64 { + fn setup(&mut self, ctx: &SelectContext<'_, G>) -> f64 { let current = ctx.index.get(ctx.current_id); ((current.stats.num_visits as f64).max(1.)).ln() } #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, parent_log: f64) -> f64 { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, parent_log: f64) -> f64 { let child = ctx.index.get(child_id); let exploit = child.stats.exploitation_score(ctx.player); let num_visits = child.stats.num_visits + child.stats.num_visits_virtual.load(Relaxed); @@ -514,7 +576,7 @@ impl SelectStrategy for ScalarAmaf { } #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, _: f64) -> f64 { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, _: f64) -> f64 { let current = ctx.index.get(ctx.current_id); current .stats @@ -560,12 +622,12 @@ fn ucb1_grave_value( grave_value(beta, mean_score, mean_amaf) + exploration_constant * explore } -impl SelectStrategy for Ucb1Grave { +impl SelectStrategy for Ucb1Grave { type Score = f64; type Aux = f64; #[inline(always)] - fn setup(&mut self, ctx: &SelectContext<'_, A>) -> f64 { + fn setup(&mut self, ctx: &SelectContext<'_, G>) -> f64 { let current = ctx.index.get(ctx.current_id); if self.current_ref_id.is_none() || current.stats.num_visits > self.threshold @@ -578,7 +640,7 @@ impl SelectStrategy for Ucb1Grave { } #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, A>, child_id: Id, parent_log: f64) -> f64 { + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, parent_log: f64) -> f64 { let current_ref = ctx.index.get(self.current_ref_id.unwrap()); let child = ctx.index.get(child_id); let mean_score = child.stats.exploitation_score(ctx.player); @@ -614,7 +676,7 @@ impl SelectStrategy for Ucb1Grave { } #[inline(always)] - fn unvisited_value(&self, ctx: &SelectContext<'_, A>, parent_log: f64) -> Self::Score { + fn unvisited_value(&self, ctx: &SelectContext<'_, G>, parent_log: f64) -> Self::Score { let current = ctx.index.get(ctx.current_id); let unvisited_value = current .stats @@ -633,3 +695,143 @@ impl SelectStrategy for Ucb1Grave { BackpropFlags(GRAVE) } } + +//////////////////////////////////////////////////////////////////////////////// + +/// Quasi Best-First comes from the Chaslot paper on Meta MCTS for opening book +/// generation. This is intended to be used differently than other strategies. +/// For opening book generation, we use the following settings for the higher +/// level MCTS config: +/// +/// - expand_threshold: 0 (expand to terminal state during select) +/// - max_iterations: 1 (we only need one PV) +/// - simulate: n/a (ignored, due to max_iteration count) +/// - backprop: n/a (ignored, due to max_iteration count) +/// +/// We add an epsilon-greedy parameter since this seems otherwise too greedy +/// a selection strategy and we don't see enough exploration. +/// +/// +/// > Algorithm 1 The “Quasi Best-First” (QBF) algorithm. λ is the number of machines +/// > available. K is a constant. g is a game, defined as a sequence of game states. +/// > The function “MoGoChoice” asks MOGO to choose a move. +/// +/// ```ignore +/// QBF(K, λ) +/// while True do +/// for l = 1..λ, do +/// s =initial state; g = {s}. +/// while s is not a final state do +/// bestScore = K +/// bestMove = Null +/// for m in the set of possible moves in s do +/// score = percentage of won games by playing the move m in s +/// if score > bestScore then +/// bestScore = score +/// bestMove = m +/// end if +/// end for +/// if bestMove = Null then +/// bestMove = MoGoChoice(s) // lower level MCTS +/// end if +/// s = playMove(s, bestMove) +/// g = concat(g, s) +/// end while +/// Add g and the result of the game in the book. +/// end for +/// end while +/// ``` +#[derive(Clone)] +pub struct QuasiBestFirst> { + pub book: book::OpeningBook, + pub search: TreeSearch, + pub epsilon: f64, + pub k: Vec, + pub key_init: Vec, +} + +impl Default for QuasiBestFirst +where + G: Game, + S: Strategy, + TreeSearch: Default, +{ + fn default() -> Self { + // The default value here is 0.5, but the Chaslot paper noted the difficulty + // of elevating the black player in go when cold starting, prompting a lower + // threshold for the initial player. + // TODO: what about N-player games where N > 2 + let mut k = vec![0.5; G::num_players()]; + if k.len() == 2 { + k[0] = 0.1; + } + + Self { + book: book::OpeningBook::new(G::num_players()), + search: TreeSearch::default(), + epsilon: 0.3, + k, + key_init: vec![], + } + } +} + +impl SelectStrategy for QuasiBestFirst +where + G: Game, + S: Strategy, + SearchConfig: Default, +{ + type Score = f64; + type Aux = (); + + fn best_child(&mut self, ctx: &SelectContext<'_, G>, rng: &mut SmallRng) -> usize { + let current = ctx.index.get(ctx.current_id); + let best = random_best_index(current.children(), self, ctx, rng); + let children = current.children(); + let score = children[best].map(|child_id| self.score_child(ctx, child_id, ())); + + if score.is_none() { + let action = self.search.choose_action(ctx.state); + current.actions().iter().position(|a| *a == action).unwrap() + } else { + best + } + } + + #[inline(always)] + fn setup(&mut self, ctx: &SelectContext<'_, G>) -> Self::Aux { + self.key_init = ctx + .stack + .iter() + .skip(1) + .map(|id| ctx.index.get(*id).action(ctx.index)) + .collect(); + } + + #[inline(always)] + fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, _: Self::Aux) -> f64 { + let child = ctx.index.get(child_id); + let action = child.action(ctx.index); + let mut key = self.key_init.clone(); + key.push(action.clone()); + + let k_score = self.k[ctx.player_to_move]; + let score = self + .book + .score(key.as_slice(), ctx.player_to_move) + .unwrap_or(f64::NEG_INFINITY); + if score > k_score { + score + } else { + // NOTE: we depend on random_best using this value internally + // as an equivalence for None types + f64::NEG_INFINITY + } + } + + #[inline(always)] + fn unvisited_value(&self, _: &SelectContext<'_, G>, _: Self::Aux) -> f64 { + f64::NEG_INFINITY + } +} diff --git a/src/strategies/mcts/simulate.rs b/src/strategies/mcts/simulate.rs index 731af57..be86fc9 100644 --- a/src/strategies/mcts/simulate.rs +++ b/src/strategies/mcts/simulate.rs @@ -8,19 +8,19 @@ use crate::game::Game; use crate::strategies::Search; use crate::util::random_best; -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum EndType { NaturalEnd, // MoveLimit, TurnLimit, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Status { pub end_type: Option, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Trial { pub actions: Vec, pub state: G::S, diff --git a/src/strategies/mcts/util.rs b/src/strategies/mcts/util.rs index d3055a5..21ca67d 100644 --- a/src/strategies/mcts/util.rs +++ b/src/strategies/mcts/util.rs @@ -298,10 +298,43 @@ impl Default for SearchConfig { impl Strategy for MetaMcts { type Select = select::Ucb1; type Simulate = simulate::MetaMcts; - type FinalAction = select::MaxAvgScore; type Backprop = backprop::Classic; + type FinalAction = select::MaxAvgScore; fn friendly_name() -> String { "meta-mcts".into() } } + +#[derive(Clone)] +pub struct QuasiBestFirst; + +impl Strategy for QuasiBestFirst { + type Select = select::EpsilonGreedy>; + type Simulate = simulate::Uniform; + type Backprop = backprop::Classic; + type FinalAction = select::MaxAvgScore; + + fn friendly_name() -> String { + "qbf/ucb1+mast".into() + } +} + +impl Default for SearchConfig { + fn default() -> Self { + Self { + select: select::EpsilonGreedy { + epsilon: 0.3, + ..Default::default() + }, + simulate: Default::default(), + backprop: Default::default(), + final_action: Default::default(), + q_init: node::UnvisitedValueEstimate::Parent, + expand_threshold: 0, + max_playout_depth: 200, + max_iterations: 1, + max_time: Default::default(), + } + } +} diff --git a/src/strategies/mod.rs b/src/strategies/mod.rs index fffe8db..2557774 100644 --- a/src/strategies/mod.rs +++ b/src/strategies/mod.rs @@ -20,6 +20,14 @@ pub trait Search: Sync + Send { } fn set_friendly_name(&mut self, name: &str); + + #[allow(unused_variables)] + fn make_book_entry( + &mut self, + state: &::S, + ) -> (Vec<::A>, Vec) { + unimplemented!(); + } } #[cfg(test)] From bc2226c2f16cd2c9677ccd605970faa57093c0bc Mon Sep 17 00:00:00 2001 From: Thomas Marsh Date: Tue, 20 Feb 2024 11:28:25 -0500 Subject: [PATCH 2/2] Fix QBF and expand0 bug --- demo/book.rs | 2 + src/strategies/mcts/mod.rs | 15 ++++++-- src/strategies/mcts/select.rs | 70 ++++++++++++++++++----------------- src/strategies/mod.rs | 7 ++-- 4 files changed, 54 insertions(+), 40 deletions(-) diff --git a/demo/book.rs b/demo/book.rs index 4709804..91ec947 100644 --- a/demo/book.rs +++ b/demo/book.rs @@ -18,6 +18,7 @@ use mcts::games::druid::State; // QBF Config const NUM_THREADS: usize = 8; const NUM_GAMES: usize = 18000; +const EPSILON: f64 = 0.5; // MCTS Config const PLAYOUT_DEPTH: usize = 200; @@ -63,6 +64,7 @@ fn make_mcts() -> TreeSearch { fn make_qbf(book: OpeningBook) -> TreeSearch { // This is a little crazy. TreeSearch::default().config(SearchConfig::default().select(select::EpsilonGreedy { + epsilon: EPSILON, inner: select::QuasiBestFirst { book, search: make_mcts(), diff --git a/src/strategies/mcts/mod.rs b/src/strategies/mcts/mod.rs index fe3969b..9c90800 100644 --- a/src/strategies/mcts/mod.rs +++ b/src/strategies/mcts/mod.rs @@ -282,7 +282,7 @@ where } else { let mut actions = Vec::new(); G::generate_actions(state, &mut actions); - assert!(!actions.is_empty()); + debug_assert!(!actions.is_empty()); node.state = NodeState::Expanded { children: vec![None; actions.len()], actions, @@ -294,6 +294,7 @@ where #[inline] pub fn select(&mut self, ctx: &mut SearchContext) { let player = G::player_to_move(&ctx.state); + debug_assert!(self.stack.is_empty()); loop { self.stack.push(ctx.current_id); @@ -351,9 +352,9 @@ where ctx.traverse(child_id); ctx.state = state; - self.stack.push(ctx.current_id); if self.config.expand_threshold > 0 { + self.stack.push(ctx.current_id); return; } } @@ -483,9 +484,13 @@ where ) } - fn reset(&mut self) -> Id { - self.index.clear(); + pub(crate) fn reset_iter(&mut self) { self.stack.clear(); + self.trial = None; + } + + pub(crate) fn reset(&mut self) -> Id { + self.index.clear(); self.stats.accum_depth = 0; self.stats.iter_count = 0; self.new_root() @@ -546,7 +551,9 @@ where if self.timer.done() { break; } + self.reset_iter(); let mut ctx = SearchContext::new(root_id, state.clone()); + self.select(&mut ctx); self.trial = Some(self.simulate(&ctx.state, G::player_to_move(state).to_index())); self.backprop(&mut ctx, G::player_to_move(state).to_index()); diff --git a/src/strategies/mcts/select.rs b/src/strategies/mcts/select.rs index f217719..e50e4f6 100644 --- a/src/strategies/mcts/select.rs +++ b/src/strategies/mcts/select.rs @@ -5,6 +5,7 @@ use rand::rngs::SmallRng; use super::*; use crate::game::Game; use crate::strategies::Search; +use crate::util::random_best; pub struct SelectContext<'a, G: Game> { pub q_init: node::UnvisitedValueEstimate, @@ -89,6 +90,7 @@ where } fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, aux: Self::Aux) -> Self::Score { + println!("greedy: score_child"); self.inner.score_child(ctx, child_id, aux) } @@ -787,51 +789,53 @@ where fn best_child(&mut self, ctx: &SelectContext<'_, G>, rng: &mut SmallRng) -> usize { let current = ctx.index.get(ctx.current_id); - let best = random_best_index(current.children(), self, ctx, rng); - let children = current.children(); - let score = children[best].map(|child_id| self.score_child(ctx, child_id, ())); + let available = current.actions(); - if score.is_none() { - let action = self.search.choose_action(ctx.state); - current.actions().iter().position(|a| *a == action).unwrap() - } else { - best - } - } - - #[inline(always)] - fn setup(&mut self, ctx: &SelectContext<'_, G>) -> Self::Aux { - self.key_init = ctx + let key_init = ctx .stack .iter() .skip(1) .map(|id| ctx.index.get(*id).action(ctx.index)) - .collect(); - } - - #[inline(always)] - fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, _: Self::Aux) -> f64 { - let child = ctx.index.get(child_id); - let action = child.action(ctx.index); - let mut key = self.key_init.clone(); - key.push(action.clone()); + .collect::>(); let k_score = self.k[ctx.player_to_move]; - let score = self - .book - .score(key.as_slice(), ctx.player_to_move) - .unwrap_or(f64::NEG_INFINITY); - if score > k_score { - score + + let enumerated = available.iter().cloned().enumerate().collect::>(); + let best = random_best(enumerated.as_slice(), rng, |(_, action): &(usize, G::A)| { + let mut key = key_init.clone(); + key.push(action.clone()); + + let score = self + .book + .score(key.as_slice(), ctx.player_to_move) + .unwrap_or(f64::NEG_INFINITY); + if score > k_score { + score + } else { + // NOTE: we depend on random_best using this value internally + // as an equivalence for None types + f64::NEG_INFINITY + } + }); + + if let Some((best_index, _)) = best { + *best_index } else { - // NOTE: we depend on random_best using this value internally - // as an equivalence for None types - f64::NEG_INFINITY + let action = self.search.choose_action(ctx.state); + available.iter().position(|p| *p == action.clone()).unwrap() } } + #[inline(always)] + fn setup(&mut self, _: &SelectContext<'_, G>) -> Self::Aux {} + + #[inline(always)] + fn score_child(&self, _: &SelectContext<'_, G>, _: Id, _: Self::Aux) -> f64 { + 0. + } + #[inline(always)] fn unvisited_value(&self, _: &SelectContext<'_, G>, _: Self::Aux) -> f64 { - f64::NEG_INFINITY + 0. } } diff --git a/src/strategies/mod.rs b/src/strategies/mod.rs index 2557774..b93badf 100644 --- a/src/strategies/mod.rs +++ b/src/strategies/mod.rs @@ -106,10 +106,10 @@ mod tests { ); // Construct new root - let root_id = ts.new_root(); - + let root_id = ts.reset(); // Helper step function let step = |ts: &mut TS| { + ts.reset_iter(); let mut ctx = mcts::SearchContext::new(root_id, init_state); ts.select(&mut ctx); let trial = ts.simulate(&ctx.state, G::player_to_move(&init_state).to_index()); @@ -120,7 +120,8 @@ mod tests { "relevant utility: {:?}", G::compute_utilities(&trial.state)[G::player_to_move(&init_state).to_index()] ); - ts.backprop(&mut ctx, trial, G::player_to_move(&init_state).to_index()); + ts.trial = Some(trial); + ts.backprop(&mut ctx, G::player_to_move(&init_state).to_index()); ctx.current_id };