Skip to content

Commit

Permalink
Transposition table support
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasmarsh committed Feb 28, 2024
1 parent dd1576a commit 158d127
Show file tree
Hide file tree
Showing 10 changed files with 305 additions and 97 deletions.
27 changes: 11 additions & 16 deletions demo/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use rand_core::SeedableRng;
const ROUNDS: usize = 20;
const PLAYOUT_DEPTH: usize = 200;
const C_TUNED: f64 = 1.625;
const MAX_ITER: usize = 10_000_000;
const MAX_ITER: usize = 10_000;
const EXPAND_THRESHOLD: u32 = 1;
const VERBOSE: bool = false;
const MAX_TIME_SECS: u64 = 0;
Expand All @@ -38,18 +38,16 @@ struct Args {
#[arg(long)]
seed: u64,

#[arg(long)]
threshold: u32,

#[arg(long)]
bias: f64,
// #[arg(long)]
// threshold: u32,

// #[arg(long)]
// bias: f64,
#[arg(long)]
c: f64,

#[arg(long)]
epsilon: f64,

// #[arg(long)]
// epsilon: f64,
#[arg(long)]
q_init: String,
}
Expand Down Expand Up @@ -123,7 +121,7 @@ fn parse_q_init(s: &str) -> Option<UnvisitedValueEstimate> {
}
}

fn make_candidate(args: Args) -> TS<util::Ucb1GraveMast> {
fn make_candidate(args: Args) -> TS<util::Ucb1> {
TS::default()
.config(
SearchConfig::default()
Expand All @@ -132,13 +130,10 @@ fn make_candidate(args: Args) -> TS<util::Ucb1GraveMast> {
.max_time(Duration::from_secs(MAX_TIME_SECS))
.expand_threshold(EXPAND_THRESHOLD)
.q_init(parse_q_init(args.q_init.as_str()).unwrap())
.select(select::Ucb1Grave {
.use_transpositions(true)
.select(select::Ucb1 {
exploration_constant: args.c,
threshold: args.threshold,
bias: args.bias,
current_ref_id: None,
})
.simulate(simulate::EpsilonGreedy::with_epsilon(args.epsilon)),
}),
)
.verbose(VERBOSE)
.rng(SmallRng::seed_from_u64(args.seed))
Expand Down
114 changes: 105 additions & 9 deletions demo/playground.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use mcts::game::Game;
use mcts::games::nim;
use mcts::games::ttt;
use mcts::strategies::flat_mc::FlatMonteCarloStrategy;
use mcts::strategies::mcts::node::UnvisitedValueEstimate;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;
use mcts::strategies::mcts::util;
Expand All @@ -25,23 +26,117 @@ type NimFlatMC = FlatMonteCarloStrategy<Nim>;
type NimMCTS = TreeSearch<Nim, util::Ucb1>;
type TttMCTS = TreeSearch<TicTacToe, util::Ucb1>;

fn ucd() {
use mcts::games::traffic_lights::TrafficLights;

type Uct = TreeSearch<TrafficLights, util::Ucb1>;
let uct = Uct::default()
.config(
SearchConfig::default()
.max_iterations(10_000)
.q_init(mcts::strategies::mcts::node::UnvisitedValueEstimate::Parent)
.expand_threshold(1)
.select(select::Ucb1 {
exploration_constant: 2.0f64.sqrt(),
}),
)
.verbose(false);

type Ucd = TreeSearch<TrafficLights, util::Ucb1>;
let mut ucd = Ucd::default()
.config(
SearchConfig::default()
.max_iterations(10_000)
.q_init(mcts::strategies::mcts::node::UnvisitedValueEstimate::Parent)
.expand_threshold(1)
.use_transpositions(true)
.q_init(UnvisitedValueEstimate::Infinity)
.select(select::Ucb1 {
exploration_constant: 0.01f64.sqrt(),
}),
)
.verbose(false);
ucd.set_friendly_name("mcts[ucb1]+ucd");

let mast: TreeSearch<TrafficLights, util::Ucb1Mast> = TreeSearch::default()
.config(
SearchConfig::default()
.expand_threshold(1)
.max_iterations(10_000)
.select(select::Ucb1 {
exploration_constant: 1.86169408634305,
})
.simulate(simulate::EpsilonGreedy::with_epsilon(0.10750788170844316)),
)
.verbose(false);

let mut mast_ucd: TreeSearch<TrafficLights, util::Ucb1Mast> = TreeSearch::default()
.config(
SearchConfig::default()
.expand_threshold(1)
.max_iterations(10_000)
.use_transpositions(true)
.q_init(UnvisitedValueEstimate::Infinity)
.select(select::Ucb1 {
exploration_constant: 0.01,
})
.simulate(simulate::EpsilonGreedy::with_epsilon(0.10750788170844316)),
)
.verbose(false);
mast_ucd.set_friendly_name("mcts[ucb1_mast]+ucd");

let tuned: TreeSearch<TrafficLights, util::Ucb1Tuned> = TreeSearch::default().config(
SearchConfig::default()
.expand_threshold(1)
.max_iterations(10_000)
.select(select::Ucb1Tuned {
exploration_constant: 1.8617,
}),
);

let mut tuned_ucd: TreeSearch<TrafficLights, util::Ucb1Tuned> = TreeSearch::default().config(
SearchConfig::default()
.expand_threshold(1)
.max_iterations(10_000)
.use_transpositions(true)
.select(select::Ucb1Tuned {
exploration_constant: 1.8617,
}),
);
tuned_ucd.set_friendly_name("mcts[ucb1_tuned]+ucd");

let mut strats = vec![
AnySearch::new(uct),
AnySearch::new(ucd),
AnySearch::new(mast),
AnySearch::new(mast_ucd),
AnySearch::new(tuned),
AnySearch::new(tuned_ucd),
];

_ = round_robin_multiple::<TrafficLights, AnySearch<_>>(
&mut strats,
1000,
&Default::default(),
mcts::util::Verbosity::Verbose,
);
}

fn traffic_lights() {
use mcts::games::traffic_lights::TrafficLights;

type TS = TreeSearch<TrafficLights, util::Ucb1GraveMast>;
type TS = TreeSearch<TrafficLights, util::Ucb1>;
let ts = TS::default()
.config(
SearchConfig::default()
.max_iterations(10_000_000)
.max_iterations(10_000)
.q_init(mcts::strategies::mcts::node::UnvisitedValueEstimate::Parent)
.expand_threshold(0)
.select(select::Ucb1Grave {
exploration_constant: 2.0f64.sqrt(),
threshold: 1000,
bias: 764.,
current_ref_id: None,
})
.simulate(simulate::EpsilonGreedy::with_epsilon(0.1865)),
.use_transpositions(true)
.q_init(UnvisitedValueEstimate::Infinity)
.select(select::Ucb1 {
exploration_constant: 0.001,
}),
)
.verbose(true);

Expand Down Expand Up @@ -356,6 +451,7 @@ fn main() {
pretty_env_logger::init();

traffic_lights();
ucd();
knightthrough();
breakthrough();
gonnect();
Expand Down
9 changes: 5 additions & 4 deletions scripts/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ class GameSearch:
def configspace(self) -> ConfigurationSpace:
cs = ConfigurationSpace(seed=0)
c = Float('c', (0, 3), default=math.sqrt(2))
bias = Float('bias', (0, 1000), default=10e-6)
threshold = Integer('threshold', (0, 1000), default=100)
epsilon = Float('epsilon', (0, 1), default=0.1)
#bias = Float('bias', (0, 1000), default=10e-6)
#threshold = Integer('threshold', (0, 1000), default=100)
#epsilon = Float('epsilon', (0, 1), default=0.1)
q_init = Categorical("q-init", ["Draw", "Infinity", "Loss", "Parent", "Win"])
cs.add_hyperparameters([c, bias, threshold, epsilon, q_init])
#cs.add_hyperparameters([c, bias, threshold, epsilon, q_init])
cs.add_hyperparameters([c, q_init])
return cs

def train(self) -> str:
Expand Down
30 changes: 9 additions & 21 deletions src/games/traffic_lights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,27 +192,7 @@ impl RectangularBoard for HashedPosition {

impl Display for HashedPosition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
RectangularBoardDisplay(self).fmt(f)?;

use super::ttt::sym;

write!(f, "-----------------------------------")?;
let mut bs = [0; 8];
sym::board_symmetries(self.position.board, &mut bs);
for b in bs {
RectangularBoardDisplay(&HashedPosition {
position: Position {
board: b,
..Default::default()
},
..Default::default()
})
.fmt(f)?;

writeln!(f)?;
}
write!(f, "-----------------------------------")?;
Ok(())
RectangularBoardDisplay(self).fmt(f)
}
}

Expand All @@ -236,6 +216,14 @@ impl Game for TrafficLights {
tmp
}

fn get_reward(init: &Self::S, term: &Self::S) -> f64 {
let utility = Self::compute_utilities(term)[Self::player_to_move(init).to_index()];
if utility < 0. {
return utility * 100.;
}
utility
}

fn notation(_state: &Self::S, m: &Self::A) -> String {
let x = m.0 % 3;
let y = m.0 / 3;
Expand Down
6 changes: 6 additions & 0 deletions src/strategies/mcts/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ where
pub max_playout_depth: usize,
pub max_iterations: usize,
pub max_time: std::time::Duration,
pub use_transpositions: bool,
}

impl<G, S> SearchConfig<G, S>
Expand Down Expand Up @@ -114,4 +115,9 @@ where
self
}
}

pub fn use_transpositions(mut self, use_transpositions: bool) -> Self {
self.use_transpositions = use_transpositions;
self
}
}
57 changes: 57 additions & 0 deletions src/strategies/mcts/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::game::Action;

use rustc_hash::FxHashMap as HashMap;
use serde::Serialize;
use std::ops::Add;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::*;

Expand All @@ -12,6 +13,16 @@ pub struct ActionStats {
pub score: f64,
}

impl Add for ActionStats {
type Output = Self;
fn add(self, rhs: Self) -> Self {
ActionStats {
num_visits: self.num_visits + rhs.num_visits,
score: self.score + rhs.score,
}
}
}

#[derive(Debug, Serialize)]
pub struct NodeStats<A: Action> {
pub num_visits: u32,
Expand Down Expand Up @@ -105,6 +116,52 @@ impl<A: Action> NodeStats<A> {
}
}

#[inline]
fn merge_grave_maps<A: Action>(
a: &HashMap<A, ActionStats>,
b: &HashMap<A, ActionStats>,
) -> HashMap<A, ActionStats> {
let mut a = a.clone();
for (key, value) in b {
match a.get_mut(key) {
Some(existing_value) => {
*existing_value = existing_value.clone() + value.clone();
}
None => {
a.insert(key.clone(), value.clone());
}
}
}
a
}

impl<A: Action> Add for NodeStats<A> {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
NodeStats {
num_visits: self.num_visits + rhs.num_visits,
num_visits_virtual: AtomicU32::new(
self.num_visits_virtual.load(Relaxed) + rhs.num_visits_virtual.load(Relaxed),
),
scores: self
.scores
.into_iter()
.zip(rhs.scores)
.map(|(x, y)| x + y)
.collect(),
sum_squared_scores: self
.sum_squared_scores
.into_iter()
.zip(rhs.sum_squared_scores)
.map(|(x, y)| x + y)
.collect(),
scalar_amaf: self.scalar_amaf + rhs.scalar_amaf,
grave_stats: merge_grave_maps(&self.grave_stats, &rhs.grave_stats),
}
}
}

// QInit:
// - MC-GRAVE: Infinity
// - MC-BRAVE: Infinity
Expand Down
6 changes: 6 additions & 0 deletions src/strategies/mcts/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ where
player_to_move: G::player_to_move(&ctx.state).to_index(),
state: &ctx.state,
index: &self.index,
table: &self.table,
use_transpositions: self.config.use_transpositions,
};
self.config.select.best_child(&select_ctx, &mut self.rng)
};
Expand Down Expand Up @@ -258,6 +260,8 @@ where
player_to_move: G::player_to_move(state).to_index(),
state,
index: &self.index,
table: &self.table,
use_transpositions: self.config.use_transpositions,
},
&mut self.rng,
);
Expand Down Expand Up @@ -400,6 +404,8 @@ where
state: &state,
player_to_move: player.to_index(),
index: &self.index,
table: &self.table,
use_transpositions: self.config.use_transpositions,
};
let best_idx = self
.config
Expand Down
Loading

0 comments on commit 158d127

Please sign in to comment.