From ffe263b9b92d248f4f8f23855953e02167311a61 Mon Sep 17 00:00:00 2001 From: Sugawara <47840708+Asugawara@users.noreply.github.com> Date: Sat, 5 Dec 2020 16:40:20 -0700 Subject: [PATCH] Copybara import of the project: -- 10a55ff6d02d733f7d24c4bb1f5618c876e643d6 by Asugawara : add nfsp -- 7908ccc1572657fd487ade88b59a4080a38a0e7d by Asugawara : add Sonnet Linear Module -- 5671c9f42cd8fa023ea8591b63ea849fdb02670c by Asugawara : action_probs: LongTensor to Tensor -- b6b9d7d435ac6059adb6057038ec096454e936ef by Asugawara : remove image and progress COPYBARA_INTEGRATE_REVIEW=https://github.com/deepmind/open_spiel/pull/450 from Asugawara:nfsp_pytorch b6b9d7d435ac6059adb6057038ec096454e936ef PiperOrigin-RevId: 345889227 Change-Id: Ib5558b3e05f4cfe96c1a9854a6956100b03ee2d4 --- open_spiel/colabs/research_nfsp_tf_pt.ipynb | 307 ++++++++++++++++ open_spiel/python/CMakeLists.txt | 1 + open_spiel/python/pytorch/dqn.py | 46 ++- open_spiel/python/pytorch/nfsp.py | 342 ++++++++++++++++++ .../python/pytorch/nfsp_pytorch_test.py | 112 ++++++ 5 files changed, 803 insertions(+), 5 deletions(-) create mode 100644 open_spiel/colabs/research_nfsp_tf_pt.ipynb create mode 100644 open_spiel/python/pytorch/nfsp.py create mode 100644 open_spiel/python/pytorch/nfsp_pytorch_test.py diff --git a/open_spiel/colabs/research_nfsp_tf_pt.ipynb b/open_spiel/colabs/research_nfsp_tf_pt.ipynb new file mode 100644 index 0000000000..de95ca071e --- /dev/null +++ b/open_spiel/colabs/research_nfsp_tf_pt.ipynb @@ -0,0 +1,307 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from absl import logging\n", + "import tensorflow.compat.v1 as tf\n", + "\n", + "from open_spiel.python import policy\n", + "from open_spiel.python import rl_environment\n", + "from open_spiel.python.algorithms import exploitability\n", + "from open_spiel.python.algorithms import nfsp\n", + "from open_spiel.python.pytorch import nfsp as nfsp_pt\n", + "\n", + "class NFSPPolicies(policy.Policy):\n", + " \"\"\"Joint policy to be evaluated.\"\"\"\n", + "\n", + " def __init__(self, env, nfsp_policies, mode):\n", + " game = env.game\n", + " player_ids = [0, 1]\n", + " super(NFSPPolicies, self).__init__(game, player_ids)\n", + " self._policies = nfsp_policies\n", + " self._mode = mode\n", + " self._obs = {\"info_state\": [None, None], \"legal_actions\": [None, None]}\n", + "\n", + " def action_probabilities(self, state, player_id=None):\n", + " cur_player = state.current_player()\n", + " legal_actions = state.legal_actions(cur_player)\n", + "\n", + " self._obs[\"current_player\"] = cur_player\n", + " self._obs[\"info_state\"][cur_player] = (\n", + " state.information_state_tensor(cur_player))\n", + " self._obs[\"legal_actions\"][cur_player] = legal_actions\n", + "\n", + " info_state = rl_environment.TimeStep(\n", + " observations=self._obs, rewards=None, discounts=None, step_type=None)\n", + "\n", + " with self._policies[cur_player].temp_mode_as(self._mode):\n", + " p = self._policies[cur_player].step(info_state, is_evaluation=True).probs\n", + " prob_dict = {action: p[action] for action in legal_actions}\n", + " return prob_dict\n", + "\n", + "\n", + "def tf_main(game,\n", + " env_config,\n", + " num_train_episodes,\n", + " eval_every,\n", + " hidden_layers_sizes,\n", + " replay_buffer_capacity,\n", + " reservoir_buffer_capacity,\n", + " anticipatory_param):\n", + " env = rl_environment.Environment(game, **env_configs)\n", + " info_state_size = env.observation_spec()[\"info_state\"][0]\n", + " 
num_actions = env.action_spec()[\"num_actions\"]\n", + "\n", + " hidden_layers_sizes = [int(l) for l in hidden_layers_sizes]\n", + " kwargs = {\n", + " \"replay_buffer_capacity\": replay_buffer_capacity,\n", + " \"epsilon_decay_duration\": num_train_episodes,\n", + " \"epsilon_start\": 0.06,\n", + " \"epsilon_end\": 0.001,\n", + " }\n", + " expl_list = []\n", + " with tf.Session() as sess:\n", + " # pylint: disable=g-complex-comprehension\n", + " agents = [\n", + " nfsp.NFSP(sess, idx, info_state_size, num_actions, hidden_layers_sizes,\n", + " reservoir_buffer_capacity, anticipatory_param,\n", + " **kwargs) for idx in range(num_players)\n", + " ]\n", + " expl_policies_avg = NFSPPolicies(env, agents, nfsp.MODE.average_policy)\n", + "\n", + " sess.run(tf.global_variables_initializer())\n", + " for ep in range(num_train_episodes):\n", + " if (ep + 1) % eval_every == 0:\n", + " losses = [agent.loss for agent in agents]\n", + " print(\"Losses: %s\" %losses)\n", + " expl = exploitability.exploitability(env.game, expl_policies_avg)\n", + " expl_list.append(expl)\n", + " print(\"[%s] Exploitability AVG %s\" %(ep + 1, expl))\n", + " print(\"_____________________________________________\")\n", + "\n", + " time_step = env.reset()\n", + " while not time_step.last():\n", + " player_id = time_step.observations[\"current_player\"]\n", + " agent_output = agents[player_id].step(time_step)\n", + " action_list = [agent_output.action]\n", + " time_step = env.step(action_list)\n", + "\n", + " # Episode is over, step all agents with final info state.\n", + " for agent in agents:\n", + " agent.step(time_step)\n", + " return expl_list\n", + " \n", + "def pt_main(game,\n", + " env_config,\n", + " num_train_episodes,\n", + " eval_every,\n", + " hidden_layers_sizes,\n", + " replay_buffer_capacity,\n", + " reservoir_buffer_capacity,\n", + " anticipatory_param):\n", + " env = rl_environment.Environment(game, **env_configs)\n", + " info_state_size = env.observation_spec()[\"info_state\"][0]\n", + " num_actions = env.action_spec()[\"num_actions\"]\n", + "\n", + " hidden_layers_sizes = [int(l) for l in hidden_layers_sizes]\n", + " kwargs = {\n", + " \"replay_buffer_capacity\": replay_buffer_capacity,\n", + " \"epsilon_decay_duration\": num_train_episodes,\n", + " \"epsilon_start\": 0.06,\n", + " \"epsilon_end\": 0.001,\n", + " }\n", + " expl_list = []\n", + " agents = [\n", + " nfsp_pt.NFSP(idx, info_state_size, num_actions, hidden_layers_sizes,\n", + " reservoir_buffer_capacity, anticipatory_param,\n", + " **kwargs) for idx in range(num_players)\n", + " ]\n", + " expl_policies_avg = NFSPPolicies(env, agents, nfsp_pt.MODE.average_policy) \n", + " for ep in range(num_train_episodes):\n", + " if (ep + 1) % eval_every == 0:\n", + " losses = [agent.loss.item() for agent in agents]\n", + " print(\"Losses: %s\" %losses)\n", + " expl = exploitability.exploitability(env.game, expl_policies_avg)\n", + " expl_list.append(expl)\n", + " print(\"[%s] Exploitability AVG %s\" %(ep + 1, expl))\n", + " print(\"_____________________________________________\") \n", + " time_step = env.reset()\n", + " while not time_step.last():\n", + " player_id = time_step.observations[\"current_player\"]\n", + " agent_output = agents[player_id].step(time_step)\n", + " action_list = [agent_output.action]\n", + " time_step = env.step(action_list) \n", + " # Episode is over, step all agents with final info state.\n", + " for agent in agents:\n", + " agent.step(time_step)\n", + " return expl_list" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": {}, + "outputs": [], + "source": [ + "game = \"kuhn_poker\"\n", + "num_players = 2\n", + "env_configs = {\"players\": num_players}\n", + "num_train_episodes = int(3e6)\n", + "eval_every = 10000\n", + "hidden_layers_sizes = [128]\n", + "replay_buffer_capacity = int(2e5)\n", + "reservoir_buffer_capacity = int(2e6)\n", + "anticipatory_param = 0.1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tf_kuhn_result = tf_main(game, \n", + " env_configs,\n", + " num_train_episodes,\n", + " eval_every,\n", + " hidden_layers_sizes,\n", + " replay_buffer_capacity,\n", + " reservoir_buffer_capacity,\n", + " anticipatory_param)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pt_kuhn_result = pt_main(game, \n", + " env_configs,\n", + " num_train_episodes,\n", + " eval_every,\n", + " hidden_layers_sizes,\n", + " replay_buffer_capacity,\n", + " reservoir_buffer_capacity,\n", + " anticipatory_param)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "x = [i*1000 for i in range(len(tf_kuhn_result))]\n", + "\n", + "plt.plot(x, tf_kuhn_result, label='tensorflow')\n", + "plt.plot(x, pt_kuhn_result, label='pytorch')\n", + "plt.title('Kuhn Poker')\n", + "plt.xlabel('Episodes')\n", + "plt.ylabel('Exploitability')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "game = \"leduc_poker\"\n", + "num_players = 2\n", + "env_configs = {\"players\": num_players}\n", + "num_train_episodes = int(3e6)\n", + "eval_every = 100000\n", + "hidden_layers_sizes = [128]\n", + "replay_buffer_capacity = int(2e5)\n", + "reservoir_buffer_capacity = int(2e6)\n", + "anticipatory_param = 0.1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tf_leduc_result = tf_main(game, \n", + " env_configs,\n", + " num_train_episodes,\n", + " eval_every,\n", + " hidden_layers_sizes,\n", + " replay_buffer_capacity,\n", + " reservoir_buffer_capacity,\n", + " anticipatory_param)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pt_leduc_result = pt_main(game, \n", + " env_configs,\n", + " num_train_episodes,\n", + " eval_every,\n", + " hidden_layers_sizes,\n", + " replay_buffer_capacity,\n", + " reservoir_buffer_capacity,\n", + " anticipatory_param)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x = [i * 10000 for i in range(len(tf_leduc_result))]\n", + "\n", + "plt.plot(x, tf_leduc_result, label='tensorflow')\n", + "plt.plot(x, pt_leduc_result, label='pytorch')\n", + "plt.title('Leduc Poker')\n", + "plt.xlabel('Episodes')\n", + "plt.ylabel('Exploitability')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, 
+ "nbformat_minor": 4 +} diff --git a/open_spiel/python/CMakeLists.txt b/open_spiel/python/CMakeLists.txt index 6bbf03600e..cdd8989204 100644 --- a/open_spiel/python/CMakeLists.txt +++ b/open_spiel/python/CMakeLists.txt @@ -155,6 +155,7 @@ set(PYTHON_TESTS ${PYTHON_TESTS} utils/spawn_test.py pytorch/rcfr_pytorch_test.py pytorch/dqn_pytorch_test.py + pytorch/nfsp_pytorch_test.py ) diff --git a/open_spiel/python/pytorch/dqn.py b/open_spiel/python/pytorch/dqn.py index 0be58589f8..adc392a945 100644 --- a/open_spiel/python/pytorch/dqn.py +++ b/open_spiel/python/pytorch/dqn.py @@ -19,8 +19,10 @@ from __future__ import print_function import collections +import math import random import numpy as np +from scipy.stats import truncnorm import torch import torch.nn as nn import torch.nn.functional as F @@ -87,6 +89,40 @@ def __iter__(self): return iter(self._data) +class SonnetLinear(nn.Module): + """A Sonnet linear module. + + Always includes biases and only supports ReLU activations. + """ + + def __init__(self, in_size, out_size, activate_relu=True): + """Creates a Sonnet linear layer. + + Args: + in_size: (int) number of inputs + out_size: (int) number of outputs + activate_relu: (bool) whether to include a ReLU activation layer + """ + super(SonnetLinear, self).__init__() + self._activate_relu = activate_relu + stddev = 1.0 / math.sqrt(in_size) + mean = 0 + lower = (-2 * stddev - mean) / stddev + upper = (2 * stddev - mean) / stddev + # Weight initialization inspired by Sonnet's Linear layer, + # which cites https://arxiv.org/abs/1502.03167v3 + # pytorch default: initialized from + # uniform(-sqrt(1/in_features), sqrt(1/in_features)) + self._weight = nn.Parameter(torch.Tensor( + truncnorm.rvs(lower, upper, loc=mean, scale=stddev, + size=[out_size, in_size]))) + self._bias = nn.Parameter(torch.zeros([out_size])) + + def forward(self, tensor): + y = F.linear(tensor, self._weight, self._bias) + return F.relu(y) if self._activate_relu else y + + class MLP(nn.Module): """A simple network built from nn.linear layers.""" @@ -108,14 +144,14 @@ def __init__(self, self._layers = [] # Hidden layers for size in hidden_sizes: - self._layers.append(nn.Linear(in_features=input_size, out_features=size)) - self._layers.append(nn.ReLU()) + self._layers.append(SonnetLinear(in_size=input_size, out_size=size)) input_size = size # Output layer self._layers.append( - nn.Linear(in_features=input_size, out_features=output_size)) - if activate_final: - self._layers.append(nn.ReLU()) + SonnetLinear( + in_size=input_size, + out_size=output_size, + activate_relu=activate_final)) self.model = nn.ModuleList(self._layers) diff --git a/open_spiel/python/pytorch/nfsp.py b/open_spiel/python/pytorch/nfsp.py new file mode 100644 index 0000000000..b51ec1e99b --- /dev/null +++ b/open_spiel/python/pytorch/nfsp.py @@ -0,0 +1,342 @@ +# Copyright 2019 DeepMind Technologies Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Neural Fictitious Self-Play (NFSP) agent implemented in PyTorch. 
+ +See the paper https://arxiv.org/abs/1603.01121 for more details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import contextlib +import enum +import os +import random +from absl import logging +import numpy as np + +import torch +import torch.nn.functional as F + +from open_spiel.python import rl_agent +from open_spiel.python.pytorch import dqn + + +Transition = collections.namedtuple( + "Transition", "info_state action_probs legal_actions_mask") + +ILLEGAL_ACTION_LOGITS_PENALTY = -1e9 + +MODE = enum.Enum("mode", "best_response average_policy") + + +class NFSP(rl_agent.AbstractAgent): + """NFSP Agent implementation in PyTorch. + + See open_spiel/python/examples/kuhn_nfsp.py for an usage example. + """ + + def __init__(self, + player_id, + state_representation_size, + num_actions, + hidden_layers_sizes, + reservoir_buffer_capacity, + anticipatory_param, + batch_size=128, + rl_learning_rate=0.01, + sl_learning_rate=0.01, + min_buffer_size_to_learn=1000, + learn_every=64, + optimizer_str="sgd", + **kwargs): + """Initialize the `NFSP` agent.""" + self.player_id = player_id + self._num_actions = num_actions + self._layer_sizes = hidden_layers_sizes + self._batch_size = batch_size + self._learn_every = learn_every + self._anticipatory_param = anticipatory_param + self._min_buffer_size_to_learn = min_buffer_size_to_learn + + self._reservoir_buffer = ReservoirBuffer(reservoir_buffer_capacity) + self._prev_timestep = None + self._prev_action = None + + # Step counter to keep track of learning. + self._step_counter = 0 + + # Inner RL agent + kwargs.update({ + "batch_size": batch_size, + "learning_rate": rl_learning_rate, + "learn_every": learn_every, + "min_buffer_size_to_learn": min_buffer_size_to_learn, + "optimizer_str": optimizer_str, + }) + self._rl_agent = dqn.DQN(player_id, state_representation_size, + num_actions, hidden_layers_sizes, **kwargs) + + # Keep track of the last training loss achieved in an update step. + self._last_rl_loss_value = lambda: self._rl_agent.loss + self._last_sl_loss_value = None + + # Average policy network. + self._avg_network = dqn.MLP(state_representation_size, + self._layer_sizes, num_actions) + + self._savers = [ + ("q_network", self._rl_agent._q_network), + ("avg_network", self._avg_network) + ] + + if optimizer_str == "adam": + self.optimizer = torch.optim.Adam( + self._avg_network.parameters(), lr=sl_learning_rate) + elif optimizer_str == "sgd": + self.optimizer = torch.optim.SGD( + self._avg_network.parameters(), lr=sl_learning_rate) + else: + raise ValueError("Not implemented. 
Choose from ['adam', 'sgd'].") + + self._sample_episode_policy() + + @contextlib.contextmanager + def temp_mode_as(self, mode): + """Context manager to temporarily overwrite the mode.""" + previous_mode = self._mode + self._mode = mode + yield + self._mode = previous_mode + + def _sample_episode_policy(self): + if np.random.rand() < self._anticipatory_param: + self._mode = MODE.best_response + else: + self._mode = MODE.average_policy + + def _act(self, info_state, legal_actions): + info_state = np.reshape(info_state, [1, -1]) + action_values = self._avg_network(torch.Tensor(info_state)) + action_probs = F.softmax(action_values, dim=1).detach() + + self._last_action_values = action_values[0] + # Remove illegal actions, normalize probs + probs = np.zeros(self._num_actions) + probs[legal_actions] = action_probs[0][legal_actions] + probs /= sum(probs) + action = np.random.choice(len(probs), p=probs) + return action, probs + + @property + def mode(self): + return self._mode + + @property + def loss(self): + return (self._last_sl_loss_value, self._last_rl_loss_value().detach()) + + def step(self, time_step, is_evaluation=False): + """Returns the action to be taken and updates the Q-networks if needed. + + Args: + time_step: an instance of rl_environment.TimeStep. + is_evaluation: bool, whether this is a training or evaluation call. + + Returns: + A `rl_agent.StepOutput` containing the action probs and chosen action. + """ + if self._mode == MODE.best_response: + agent_output = self._rl_agent.step(time_step, is_evaluation) + if not is_evaluation and not time_step.last(): + self._add_transition(time_step, agent_output) + + elif self._mode == MODE.average_policy: + # Act step: don't act at terminal info states. + if not time_step.last(): + info_state = time_step.observations["info_state"][self.player_id] + legal_actions = time_step.observations["legal_actions"][self.player_id] + action, probs = self._act(info_state, legal_actions) + agent_output = rl_agent.StepOutput(action=action, probs=probs) + + if self._prev_timestep and not is_evaluation: + self._rl_agent.add_transition(self._prev_timestep, self._prev_action, + time_step) + else: + raise ValueError("Invalid mode ({})".format(self._mode)) + + if not is_evaluation: + self._step_counter += 1 + + if self._step_counter % self._learn_every == 0: + self._last_sl_loss_value = self._learn() + # If learn step not triggered by rl policy, learn. + if self._mode == MODE.average_policy: + self._rl_agent.learn() + + # Prepare for the next episode. + if time_step.last(): + self._sample_episode_policy() + self._prev_timestep = None + self._prev_action = None + return + else: + self._prev_timestep = time_step + self._prev_action = agent_output.action + + return agent_output + + def _add_transition(self, time_step, agent_output): + """Adds the new transition using `time_step` to the reservoir buffer. + + Transitions are in the form (time_step, agent_output.probs, legal_mask). + + Args: + time_step: an instance of rl_environment.TimeStep. + agent_output: an instance of rl_agent.StepOutput. 
+ """ + legal_actions = time_step.observations["legal_actions"][self.player_id] + legal_actions_mask = np.zeros(self._num_actions) + legal_actions_mask[legal_actions] = 1.0 + transition = Transition( + info_state=(time_step.observations["info_state"][self.player_id][:]), + action_probs=agent_output.probs, + legal_actions_mask=legal_actions_mask) + self._reservoir_buffer.add(transition) + + def _learn(self): + """Compute the loss on sampled transitions and perform a avg-network update. + + If there are not enough elements in the buffer, no loss is computed and + `None` is returned instead. + + Returns: + The average loss obtained on this batch of transitions or `None`. + """ + if (len(self._reservoir_buffer) < self._batch_size or + len(self._reservoir_buffer) < self._min_buffer_size_to_learn): + return None + + transitions = self._reservoir_buffer.sample(self._batch_size) + info_states = torch.Tensor([t.info_state for t in transitions]) + action_probs = torch.Tensor([t.action_probs for t in transitions]) + + self.optimizer.zero_grad() + loss = F.cross_entropy(self._avg_network(info_states), + torch.max(action_probs, dim=1)[1]) + loss.backward() + self.optimizer.step() + return loss.detach() + + def _full_checkpoint_name(self, checkpoint_dir, name): + checkpoint_filename = "_".join([name, "pid" + str(self.player_id)]) + return os.path.join(checkpoint_dir, checkpoint_filename) + + def _latest_checkpoint_filename(self, name): + checkpoint_filename = "_".join([name, "pid" + str(self.player_id)]) + return checkpoint_filename + "_latest" + + def save(self, checkpoint_dir): + """Saves the average policy network and the inner RL agent's q-network. + + Note that this does not save the experience replay buffers and should + only be used to restore the agent's policy, not resume training. + + Args: + checkpoint_dir: directory where checkpoints will be saved. + """ + for name, model in self._savers: + path = self._full_checkpoint_name(checkpoint_dir, name) + torch.save(model.state_dict(), path) + logging.info("Saved to path: %s", path) + + def has_checkpoint(self, checkpoint_dir): + for name, _ in self._savers: + path = self._full_checkpoint_name(checkpoint_dir, name) + if os.path.exists(path): + return True + return False + + def restore(self, checkpoint_dir): + """Restores the average policy network and the inner RL agent's q-network. + + Note that this does not restore the experience replay buffers and should + only be used to restore the agent's policy, not resume training. + + Args: + checkpoint_dir: directory from which checkpoints will be restored. + """ + for name, model in self._savers: + full_checkpoint_dir = self._full_checkpoint_name(checkpoint_dir, name) + logging.info("Restoring checkpoint: %s", full_checkpoint_dir) + model.load_state_dict(torch.load(full_checkpoint_dir)) + + +class ReservoirBuffer(object): + """Allows uniform sampling over a stream of data. + + This class supports the storage of arbitrary elements, such as observation + tensors, integer actions, etc. + + See https://en.wikipedia.org/wiki/Reservoir_sampling for more details. + """ + + def __init__(self, reservoir_buffer_capacity): + self._reservoir_buffer_capacity = reservoir_buffer_capacity + self._data = [] + self._add_calls = 0 + + def add(self, element): + """Potentially adds `element` to the reservoir buffer. + + Args: + element: data to be added to the reservoir buffer. 
+ """ + if len(self._data) < self._reservoir_buffer_capacity: + self._data.append(element) + else: + idx = np.random.randint(0, self._add_calls + 1) + if idx < self._reservoir_buffer_capacity: + self._data[idx] = element + self._add_calls += 1 + + def sample(self, num_samples): + """Returns `num_samples` uniformly sampled from the buffer. + + Args: + num_samples: `int`, number of samples to draw. + + Returns: + An iterable over `num_samples` random elements of the buffer. + + Raises: + ValueError: If there are less than `num_samples` elements in the buffer + """ + if len(self._data) < num_samples: + raise ValueError("{} elements could not be sampled from size {}".format( + num_samples, len(self._data))) + return random.sample(self._data, num_samples) + + def clear(self): + self._data = [] + self._add_calls = 0 + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(self._data) diff --git a/open_spiel/python/pytorch/nfsp_pytorch_test.py b/open_spiel/python/pytorch/nfsp_pytorch_test.py new file mode 100644 index 0000000000..f4e21e400a --- /dev/null +++ b/open_spiel/python/pytorch/nfsp_pytorch_test.py @@ -0,0 +1,112 @@ +# Copyright 2019 DeepMind Technologies Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for open_spiel.python.algorithms.nfsp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from scipy import stats + +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase + +from open_spiel.python import rl_environment +from open_spiel.python.pytorch import nfsp + + +class NFSPTest(TestCase): + + def test_run_kuhn(self): + env = rl_environment.Environment("kuhn_poker") + state_size = env.observation_spec()["info_state"][0] + num_actions = env.action_spec()["num_actions"] + + agents = [ + nfsp.NFSP( # pylint: disable=g-complex-comprehension + player_id, + state_representation_size=state_size, + num_actions=num_actions, + hidden_layers_sizes=[16], + reservoir_buffer_capacity=10, + anticipatory_param=0.1) for player_id in [0, 1] + ] + for unused_ep in range(10): + time_step = env.reset() + while not time_step.last(): + current_player = time_step.observations["current_player"] + current_agent = agents[current_player] + agent_output = current_agent.step(time_step) + time_step = env.step([agent_output.action]) + for agent in agents: + agent.step(time_step) + + +class ReservoirBufferTest(TestCase): + + def test_reservoir_buffer_add(self): + reservoir_buffer = nfsp.ReservoirBuffer(reservoir_buffer_capacity=10) + self.assertEqual(len(reservoir_buffer), 0) + reservoir_buffer.add("entry1") + self.assertEqual(len(reservoir_buffer), 1) + reservoir_buffer.add("entry2") + self.assertEqual(len(reservoir_buffer), 2) + + self.assertIn("entry1", reservoir_buffer) + self.assertIn("entry2", reservoir_buffer) + + def test_reservoir_buffer_max_capacity(self): + reservoir_buffer = nfsp.ReservoirBuffer(reservoir_buffer_capacity=2) + reservoir_buffer.add("entry1") + reservoir_buffer.add("entry2") + reservoir_buffer.add("entry3") + + self.assertEqual(len(reservoir_buffer), 2) + + def test_reservoir_uniform(self): + size = 10 + max_value = 100 + num_trials = 1000 + expected_count = 1. / max_value * size * num_trials + + reservoir_buffer = nfsp.ReservoirBuffer(reservoir_buffer_capacity=size) + counter = collections.Counter() + for _ in range(num_trials): + reservoir_buffer.clear() + for idx in range(max_value): + reservoir_buffer.add(idx) + data = reservoir_buffer.sample(size) + counter.update(data) + # Tests the null hypothesis (H0) that data has the given frequencies. + # We reject the null hypothesis if we get a p-value below our threshold. + pvalue = stats.chisquare(list(counter.values()), expected_count).pvalue + self.assertGreater(pvalue, 0.05) # We cannot reject H0. + + def test_reservoir_buffer_sample(self): + replay_buffer = nfsp.ReservoirBuffer(reservoir_buffer_capacity=3) + replay_buffer.add("entry1") + replay_buffer.add("entry2") + replay_buffer.add("entry3") + + samples = replay_buffer.sample(3) + + self.assertIn("entry1", samples) + self.assertIn("entry2", samples) + self.assertIn("entry3", samples) + + +if __name__ == "__main__": + run_tests()