Skip to content

Commit

Permalink
Correct AMAF logic
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasmarsh committed Feb 29, 2024
1 parent 631ef6c commit db3a674
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 51 deletions.
14 changes: 6 additions & 8 deletions demo/druid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,17 @@ fn main() {
)
.verbose(VERBOSE);

let mut amaf: TreeSearch<Druid, strategy::ScalarAmaf> = TreeSearch::default()
.config(base_config().select(select::ScalarAmaf {
bias: BIAS,
exploration_constant: C_LOW,
let mut amaf: TreeSearch<Druid, strategy::Amaf> = TreeSearch::default()
.config(base_config().select(select::Amaf {
exploration_constant: C_TUNED,
}))
.verbose(VERBOSE);

let mut amaf_mast: TreeSearch<Druid, strategy::ScalarAmafMast> = TreeSearch::default()
let mut amaf_mast: TreeSearch<Druid, strategy::AmafMast> = TreeSearch::default()
.config(
base_config()
.select(select::ScalarAmaf {
bias: BIAS,
exploration_constant: C_LOW,
.select(select::Amaf {
exploration_constant: C_TUNED,
})
.simulate(simulate::EpsilonGreedy::with_epsilon(0.1)),
)
Expand Down
25 changes: 13 additions & 12 deletions src/strategies/mcts/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::index::Id;
use super::*;
use crate::game::{Game, PlayerIndex};

use rustc_hash::FxHashMap as HashMap;
use rustc_hash::FxHashMap;

pub trait BackpropStrategy: Clone + Sync + Send {
// TODO: cleanup the arguments to this, or just move it to TreeSearch
Expand Down Expand Up @@ -38,24 +38,25 @@ pub trait BackpropStrategy: Clone + Sync + Send {
// update: AMAF
if flags.amaf() {
let node = index.get(node_id);

if node.is_expanded() {
let child_actions: HashMap<_, _> = node
if !node.is_root() {
let parent_id = node.parent_id;
assert!(!stack.is_empty());
assert_eq!(parent_id, *stack.last().unwrap());
let parent = index.get(parent_id);
let sibling_actions: FxHashMap<_, _> = parent
.actions()
.iter()
.cloned()
.zip(node.children().iter().cloned().flatten())
.zip(parent.children().iter().cloned().flatten())
.collect();

for action in &trial.actions {
if let Some(child_id) = child_actions.get(action) {
if let Some(child_id) = sibling_actions.get(action) {
let child = index.get_mut(*child_id);
child.stats.scalar_amaf.num_visits += 1;

// TODO: I'm not convinced which is the right update strategy for this one
// child.stats.scalar_amaf.score += utilities[player];
child.stats.scalar_amaf.score +=
utilities[G::player_to_move(&ctx.state).to_index()];
(0..G::num_players()).for_each(|i| {
child.stats.amaf[i].num_visits += 1;
child.stats.amaf[i].score += utilities[i];
})
}
}
}
Expand Down
16 changes: 10 additions & 6 deletions src/strategies/mcts/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ pub struct NodeStats<A: Action> {
// Only used for UCB1Tuned; how to parameterize
pub sum_squared_scores: Vec<f64>,

// TODO: what is this actually? I can't find this in the literature, but it's
// like a coarser version of RAVE/AMAF.
pub scalar_amaf: ActionStats,
pub amaf: Vec<ActionStats>,

// TODO: Only used for GRAVE; how to parameterize
#[serde(skip_serializing)]
Expand All @@ -51,7 +49,7 @@ impl<A: Action> Clone for NodeStats<A> {
num_visits_virtual: AtomicU32::new(self.num_visits_virtual.load(Relaxed)),
scores: self.scores.clone(),
sum_squared_scores: self.sum_squared_scores.clone(),
scalar_amaf: self.scalar_amaf.clone(),
amaf: self.amaf.clone(),
grave_stats: self.grave_stats.clone(),
}
}
Expand All @@ -64,7 +62,7 @@ impl<A: Action> NodeStats<A> {
num_visits_virtual: AtomicU32::new(0),
scores: vec![0.; num_players],
sum_squared_scores: vec![0.; num_players],
scalar_amaf: Default::default(),
amaf: vec![ActionStats::default(); num_players],
grave_stats: Default::default(),
}
}
Expand Down Expand Up @@ -125,6 +123,7 @@ impl<A: Action> Add for NodeStats<A> {
num_visits_virtual: AtomicU32::new(
self.num_visits_virtual.load(Relaxed) + rhs.num_visits_virtual.load(Relaxed),
),
// TODO: group per-player stats to avoid N*M loops
scores: self
.scores
.into_iter()
Expand All @@ -137,7 +136,12 @@ impl<A: Action> Add for NodeStats<A> {
.zip(rhs.sum_squared_scores)
.map(|(x, y)| x + y)
.collect(),
scalar_amaf: self.scalar_amaf + rhs.scalar_amaf,
amaf: self
.amaf
.into_iter()
.zip(rhs.amaf)
.map(|(x, y)| x + y)
.collect(),
// NOTE: GRAVE is not currently supported with transpositions
grave_stats: HashMap::default(),
}
Expand Down
26 changes: 11 additions & 15 deletions src/strategies/mcts/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,21 +598,19 @@ impl<G: Game> SelectStrategy<G> for McBrave {
// This one was found in some implementations of RAVE. It seems strong, but I
// can't find references to it in the literature.
#[derive(Clone)]
pub struct ScalarAmaf {
pub struct Amaf {
pub exploration_constant: f64,
pub bias: f64,
}

impl Default for ScalarAmaf {
impl Default for Amaf {
fn default() -> Self {
Self {
exploration_constant: 2f64.sqrt(),
bias: 700.0,
}
}
}

impl<G: Game> SelectStrategy<G> for ScalarAmaf {
impl<G: Game> SelectStrategy<G> for Amaf {
type Score = f64;
type Aux = f64;

Expand All @@ -624,20 +622,18 @@ impl<G: Game> SelectStrategy<G> for ScalarAmaf {
#[inline(always)]
fn score_child(&self, ctx: &SelectContext<'_, G>, child_id: Id, parent_log: f64) -> f64 {
let stats = ctx.child_stats(child_id);
let amaf_n = stats.amaf[ctx.player].num_visits;
let amaf_q = stats.amaf[ctx.player].score;

let avg_amaf_score = amaf_q / amaf_n as f64;
let exploit = stats.exploitation_score(ctx.player);

let num_visits = stats.num_visits + stats.num_visits_virtual.load(Relaxed);
let explore = (parent_log / num_visits as f64).sqrt();
let uct_value = exploit + self.exploration_constant * explore;

let amaf_value = if num_visits > 0 {
stats.scalar_amaf.score / stats.num_visits as f64
} else {
0.
};

let beta = self.bias / (self.bias + num_visits as f64);

(1. - beta) * uct_value + beta * amaf_value
exploit
+ self.exploration_constant * explore
+ avg_amaf_score * (explore / 1.max(amaf_n) as f64)
}

#[inline(always)]
Expand Down
20 changes: 10 additions & 10 deletions src/strategies/mcts/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,20 @@ impl<G: Game> Default for SearchConfig<G, Ucb1Mast> {
}

#[derive(Clone)]
pub struct ScalarAmaf;
pub struct Amaf;

impl<G: Game> Strategy<G> for ScalarAmaf {
type Select = select::ScalarAmaf;
impl<G: Game> Strategy<G> for Amaf {
type Select = select::Amaf;
type Simulate = simulate::Uniform;
type Backprop = backprop::Classic;
type FinalAction = select::RobustChild;

fn friendly_name() -> String {
"scalar_amaf".into()
"amaf".into()
}
}

impl<G: Game> Default for SearchConfig<G, ScalarAmaf> {
impl<G: Game> Default for SearchConfig<G, Amaf> {
fn default() -> Self {
Self {
select: Default::default(),
Expand All @@ -136,20 +136,20 @@ impl<G: Game> Default for SearchConfig<G, ScalarAmaf> {
}

#[derive(Clone)]
pub struct ScalarAmafMast;
pub struct AmafMast;

impl<G: Game> Strategy<G> for ScalarAmafMast {
type Select = select::ScalarAmaf;
impl<G: Game> Strategy<G> for AmafMast {
type Select = select::Amaf;
type Simulate = simulate::EpsilonGreedy<G, simulate::Mast>;
type Backprop = backprop::Classic;
type FinalAction = select::RobustChild;

fn friendly_name() -> String {
"scalar_amaf+mast".into()
"amaf+mast".into()
}
}

impl<G: Game> Default for SearchConfig<G, ScalarAmafMast> {
impl<G: Game> Default for SearchConfig<G, AmafMast> {
fn default() -> Self {
Self {
select: Default::default(),
Expand Down

0 comments on commit db3a674

Please sign in to comment.