Skip to content

Commit

Permalink
Add simple graphviz rendering and work on traffic lights
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasmarsh committed Feb 25, 2024
1 parent d188e9c commit e50aaf1
Show file tree
Hide file tree
Showing 11 changed files with 385 additions and 58 deletions.
9 changes: 1 addition & 8 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,4 @@ Cargo.lock
__pycache__
smac3_output
/profile.json



# Added by cargo
#
# already existing elements were commented out

#/target
.DS_Store
52 changes: 45 additions & 7 deletions demo/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use std::time::Duration;

use clap::Parser;

use mcts::games::atarigo::AtariGo;
use mcts::games::atarigo::State;
use mcts::game::Game;
use mcts::strategies::mcts::node::UnvisitedValueEstimate;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;
use mcts::strategies::mcts::util;
Expand All @@ -21,12 +21,14 @@ use rand_core::SeedableRng;
const ROUNDS: usize = 20;
const PLAYOUT_DEPTH: usize = 200;
const C_TUNED: f64 = 1.625;
const MAX_ITER: usize = 10000;
const EXPAND_THRESHOLD: u32 = 5;
const MAX_ITER: usize = 10_000_000;
const EXPAND_THRESHOLD: u32 = 1;
const VERBOSE: bool = false;
const MAX_TIME_SECS: u64 = 0;

type G = AtariGo<8>;
use mcts::games::ttt_traffic_lights;

type G = ttt_traffic_lights::TttTrafficLights;

type TS<S> = TreeSearch<G, S>;

Expand All @@ -47,6 +49,9 @@ struct Args {

#[arg(long)]
epsilon: f64,

#[arg(long)]
q_init: String,
}

fn main() {
Expand All @@ -58,7 +63,7 @@ fn main() {
let results = round_robin_multiple::<G, AnySearch<'_, G>>(
&mut strategies,
ROUNDS,
&State::default(),
&<G as Game>::S::default(),
Verbosity::Silent,
);
let cost = calc_cost(results);
Expand All @@ -70,7 +75,28 @@ fn calc_cost(results: Vec<mcts::util::Result>) -> f64 {
1.0 - w / (ROUNDS * 2) as f64
}

fn make_opponent(seed: u64) -> TS<util::Ucb1> {
/// Builds the opponent tree-search strategy.
///
/// Iteration, playout-depth, time and expansion limits are taken from the
/// file-level constants; the selection and simulation parameters are fixed
/// literals. The search RNG is seeded from `seed`.
fn make_opponent(seed: u64) -> TS<util::Ucb1GraveMast> {
    // Fixed selection parameters for the opponent.
    let selector = select::Ucb1Grave {
        exploration_constant: 0.69535,
        threshold: 285,
        bias: 628.,
        current_ref_id: None,
    };
    // Fixed epsilon for the epsilon-greedy playout policy.
    let playout = simulate::EpsilonGreedy::with_epsilon(0.0015);
    let time_limit = Duration::from_secs(MAX_TIME_SECS);
    let rng = SmallRng::seed_from_u64(seed);

    TS::default()
        .config(
            SearchConfig::default()
                .max_iterations(MAX_ITER)
                .max_playout_depth(PLAYOUT_DEPTH)
                .max_time(time_limit)
                .expand_threshold(EXPAND_THRESHOLD)
                .q_init(UnvisitedValueEstimate::Parent)
                .select(selector)
                .simulate(playout),
        )
        .verbose(VERBOSE)
        .rng(rng)
}

fn make_opponent_(seed: u64) -> TS<util::Ucb1> {
TS::default()
.config(
SearchConfig::default()
Expand All @@ -86,6 +112,17 @@ fn make_opponent(seed: u64) -> TS<util::Ucb1> {
.rng(SmallRng::seed_from_u64(seed))
}

fn parse_q_init(s: &str) -> Option<UnvisitedValueEstimate> {
match s {
"Draw" => Some(UnvisitedValueEstimate::Draw),
"Infinity" => Some(UnvisitedValueEstimate::Infinity),
"Loss" => Some(UnvisitedValueEstimate::Loss),
"Parent" => Some(UnvisitedValueEstimate::Parent),
"Win" => Some(UnvisitedValueEstimate::Win),
_ => None,
}
}

fn make_candidate(args: Args) -> TS<util::Ucb1GraveMast> {
TS::default()
.config(
Expand All @@ -94,6 +131,7 @@ fn make_candidate(args: Args) -> TS<util::Ucb1GraveMast> {
.max_playout_depth(PLAYOUT_DEPTH)
.max_time(Duration::from_secs(MAX_TIME_SECS))
.expand_threshold(EXPAND_THRESHOLD)
.q_init(parse_q_init(args.q_init.as_str()).unwrap())
.select(select::Ucb1Grave {
exploration_constant: args.c,
threshold: args.threshold,
Expand Down
66 changes: 43 additions & 23 deletions demo/playground.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use mcts::game::Game;
use mcts::games::nim;
use mcts::games::ttt;
use mcts::strategies::flat_mc::FlatMonteCarloStrategy;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;
use mcts::strategies::mcts::util;
use mcts::strategies::mcts::SearchConfig;
use mcts::strategies::mcts::TreeSearch;
Expand All @@ -23,36 +25,44 @@ type NimFlatMC = FlatMonteCarloStrategy<Nim>;
type NimMCTS = TreeSearch<Nim, util::Ucb1>;
type TttMCTS = TreeSearch<TicTacToe, util::Ucb1>;

fn summarize(label_a: &str, label_b: &str, results: Vec<Option<usize>>) {
let (win_a, win_b, draw) = results.iter().fold((0, 0, 0), |(a, b, c), x| match x {
Some(0) => (a + 1, b, c),
Some(1) => (a, b + 1, c),
None => (a, b, c + 1),
_ => (a, b, c),
});
let total = (win_a + win_b + draw) as f32;
let pct_a = win_a as f32 / total * 100.;
let pct_b = win_b as f32 / total * 100.;
println!("{label_a} / {label_b}: {win_a} ({pct_a:.2}%) / {win_b} ({pct_b:.2}%) [{draw} draws]");
/// Configures a verbose tree search for `TttTrafficLights` and hands it to
/// `self_play`.
fn traffic_lights() {
    use mcts::games::ttt_traffic_lights::TttTrafficLights;
    use mcts::strategies::mcts::node::UnvisitedValueEstimate;

    type TS = TreeSearch<TttTrafficLights, util::Ucb1GraveMast>;

    // Selection parameters; the exploration constant is sqrt(2).
    let selector = select::Ucb1Grave {
        exploration_constant: 2.0f64.sqrt(),
        threshold: 1000,
        bias: 764.,
        current_ref_id: None,
    };
    let playout = simulate::EpsilonGreedy::with_epsilon(0.1865);

    let search = TS::default()
        .config(
            SearchConfig::default()
                .max_iterations(10_000_000)
                .q_init(UnvisitedValueEstimate::Parent)
                .expand_threshold(0)
                .select(selector)
                .simulate(playout),
        )
        .verbose(true);

    self_play(search);
}

fn knightthrough() {
use mcts::games::knightthrough::Knightthrough;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;

type TS = TreeSearch<Knightthrough<8, 8>, util::Ucb1GraveMast>;
let ts = TS::default()
.config(
SearchConfig::default()
.max_time(Duration::from_secs(10))
.max_iterations(20000)
.select(select::Ucb1Grave {
exploration_constant: 1.32562,
threshold: 720,
bias: 430.36,
exploration_constant: 2.12652,
threshold: 131,
bias: 68.65,
current_ref_id: None,
})
.simulate(simulate::EpsilonGreedy::with_epsilon(0.98)),
.simulate(simulate::EpsilonGreedy::with_epsilon(0.12)),
)
.verbose(true);

Expand All @@ -61,8 +71,6 @@ fn knightthrough() {

fn breakthrough() {
use mcts::games::breakthrough::Breakthrough;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;

type TS = TreeSearch<Breakthrough<6, 4>, util::Ucb1GraveMast>;
let ts = TS::default()
Expand All @@ -84,8 +92,6 @@ fn breakthrough() {

fn atarigo() {
use mcts::games::atarigo::AtariGo;
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::simulate;

type TS = TreeSearch<AtariGo<5>, util::Ucb1GraveMast>;
let ts = TS::default()
Expand All @@ -107,7 +113,6 @@ fn atarigo() {

fn gonnect() {
use mcts::games::gonnect::Gonnect;
use mcts::strategies::mcts::select;

type TS = TreeSearch<Gonnect<7>, util::Ucb1Grave>;
let ts = TS::default()
Expand All @@ -130,6 +135,7 @@ fn gonnect() {

fn expansion_test() {
use mcts::games::bid_ttt as ttt;

type TS = TreeSearch<ttt::BiddingTicTacToe, util::Ucb1>;

let expand5 = TS::default()
Expand Down Expand Up @@ -178,6 +184,19 @@ fn ucb_test() {
);
}

/// Prints a one-line win/draw summary for a two-sided match-up.
///
/// Each entry of `results` is one game: `Some(0)` counts as a win for
/// `label_a`, `Some(1)` as a win for `label_b`, and `None` as a draw. Any
/// other `Some(_)` value is ignored and does not contribute to the total.
///
/// Percentages are relative to the number of counted games. When no game was
/// counted (for example, `results` is empty) both percentages are reported as
/// `0.00%` instead of `NaN`.
fn summarize(label_a: &str, label_b: &str, results: Vec<Option<usize>>) {
    let (win_a, win_b, draw) =
        results
            .iter()
            .fold((0i32, 0i32, 0i32), |(a, b, c), outcome| match outcome {
                Some(0) => (a + 1, b, c),
                Some(1) => (a, b + 1, c),
                None => (a, b, c + 1),
                _ => (a, b, c),
            });
    let total = win_a + win_b + draw;
    // Guard the division: an empty tally would otherwise yield 0/0 = NaN.
    let percent = |wins: i32| -> f32 {
        if total == 0 {
            0.
        } else {
            wins as f32 / total as f32 * 100.
        }
    };
    let pct_a = percent(win_a);
    let pct_b = percent(win_b);
    println!("{label_a} / {label_b}: {win_a} ({pct_a:.2}%) / {win_b} ({pct_b:.2}%) [{draw} draws]");
}

struct BattleConfig {
num_samples: usize, // number of games to play
samples_per_move: Vec<usize>,
Expand Down Expand Up @@ -336,6 +355,7 @@ fn main() {
color_backtrace::install();
pretty_env_logger::init();

traffic_lights();
knightthrough();
breakthrough();
gonnect();
Expand Down
7 changes: 4 additions & 3 deletions scripts/hyper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ConfigSpace import ConfigurationSpace
from smac import BlackBoxFacade, Scenario
from pathlib import Path
from ConfigSpace import Float, Integer
from ConfigSpace import Categorical, Float, Integer

import math
import os
Expand All @@ -21,7 +21,8 @@ def configspace(self) -> ConfigurationSpace:
bias = Float('bias', (0, 1000), default=10e-6)
threshold = Integer('threshold', (0, 1000), default=100)
epsilon = Float('epsilon', (0, 1), default=0.1)
cs.add_hyperparameters([c, bias, threshold, epsilon])
q_init = Categorical("q-init", ["Draw", "Infinity", "Loss", "Parent", "Win"])
cs.add_hyperparameters([c, bias, threshold, epsilon, q_init])
return cs

def train(self) -> str:
Expand All @@ -35,7 +36,7 @@ def train(self) -> str:
scenario = Scenario(
model.configspace,
deterministic=True,
n_trials=30,
n_trials=100,
n_workers=(os.cpu_count() // 2))

# Now we use SMAC to find the best hyperparameters
Expand Down
23 changes: 22 additions & 1 deletion src/games/gonnect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,35 @@ impl<const N: usize> fmt::Display for MovesDisplay<N> {
}
}

// Test-only impl of the `NodeRender` trait from `strategies::mcts::render`
// for gonnect's `State<N>`. The impl body is empty, so every trait item
// falls back to the trait's own defaults.
#[cfg(test)]
impl<const N: usize> crate::strategies::mcts::render::NodeRender for State<N> {}

#[cfg(test)]
mod tests {
use crate::util::random_play;
use crate::{
strategies::{
mcts::{render, util, SearchConfig, TreeSearch},
Search,
},
util::random_play,
};

use super::*;

/// Smoke test: runs `random_play` on a 3-sized Gonnect game.
#[test]
fn test_gonnect() {
    type SmallGonnect = Gonnect<3>;

    random_play::<SmallGonnect>();
}

/// Runs one search from the default `Gonnect<8>` state and passes the
/// resulting search object to `render::render`.
#[test]
fn test_render() {
    use crate::strategies::mcts::node::UnvisitedValueEstimate;

    type Searcher = TreeSearch<Gonnect<8>, util::Ucb1>;

    let mut search = Searcher::default().config(
        SearchConfig::default()
            .expand_threshold(1)
            .q_init(UnvisitedValueEstimate::Draw)
            .max_iterations(50000),
    );

    // Only the search state built up by this call matters; the chosen action
    // itself is discarded.
    let _ = search.choose_action(&State::default());
    render::render(&search);
}
}
1 change: 1 addition & 0 deletions src/games/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod gonnect;
pub mod knightthrough;
pub mod nim;
pub mod null;
pub mod shibumi;
pub mod ttt;
pub mod ttt_traffic_lights;
pub mod unit;
Expand Down
Loading

0 comments on commit e50aaf1

Please sign in to comment.