Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
10a55ff by Asugawara <asgasw@gmail.com>:

add nfsp

--
7908ccc by Asugawara <asgasw@gmail.com>:

add Sonnet Linear Module

--
5671c9f by Asugawara <asgasw@gmail.com>:

action_probs: LongTensor to Tensor

--
b6b9d7d by Asugawara <asgasw@gmail.com>:

remove image and progress

COPYBARA_INTEGRATE_REVIEW=#450 from Asugawara:nfsp_pytorch b6b9d7d
PiperOrigin-RevId: 345889227
Change-Id: Ib5558b3e05f4cfe96c1a9854a6956100b03ee2d4
  • Loading branch information
Asugawara authored and open_spiel@google.com committed Dec 6, 2020
1 parent 9a343a1 commit ffe263b
Show file tree
Hide file tree
Showing 5 changed files with 803 additions and 5 deletions.
307 changes: 307 additions & 0 deletions open_spiel/colabs/research_nfsp_tf_pt.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from absl import logging\n",
"import tensorflow.compat.v1 as tf\n",
"\n",
"from open_spiel.python import policy\n",
"from open_spiel.python import rl_environment\n",
"from open_spiel.python.algorithms import exploitability\n",
"from open_spiel.python.algorithms import nfsp\n",
"from open_spiel.python.pytorch import nfsp as nfsp_pt\n",
"\n",
"class NFSPPolicies(policy.Policy):\n",
" \"\"\"Joint policy to be evaluated.\"\"\"\n",
"\n",
" def __init__(self, env, nfsp_policies, mode):\n",
" game = env.game\n",
" player_ids = [0, 1]\n",
" super(NFSPPolicies, self).__init__(game, player_ids)\n",
" self._policies = nfsp_policies\n",
" self._mode = mode\n",
" self._obs = {\"info_state\": [None, None], \"legal_actions\": [None, None]}\n",
"\n",
" def action_probabilities(self, state, player_id=None):\n",
" cur_player = state.current_player()\n",
" legal_actions = state.legal_actions(cur_player)\n",
"\n",
" self._obs[\"current_player\"] = cur_player\n",
" self._obs[\"info_state\"][cur_player] = (\n",
" state.information_state_tensor(cur_player))\n",
" self._obs[\"legal_actions\"][cur_player] = legal_actions\n",
"\n",
" info_state = rl_environment.TimeStep(\n",
" observations=self._obs, rewards=None, discounts=None, step_type=None)\n",
"\n",
" with self._policies[cur_player].temp_mode_as(self._mode):\n",
" p = self._policies[cur_player].step(info_state, is_evaluation=True).probs\n",
" prob_dict = {action: p[action] for action in legal_actions}\n",
" return prob_dict\n",
"\n",
"\n",
"def tf_main(game,\n",
" env_config,\n",
" num_train_episodes,\n",
" eval_every,\n",
" hidden_layers_sizes,\n",
" replay_buffer_capacity,\n",
" reservoir_buffer_capacity,\n",
" anticipatory_param):\n",
" env = rl_environment.Environment(game, **env_configs)\n",
" info_state_size = env.observation_spec()[\"info_state\"][0]\n",
" num_actions = env.action_spec()[\"num_actions\"]\n",
"\n",
" hidden_layers_sizes = [int(l) for l in hidden_layers_sizes]\n",
" kwargs = {\n",
" \"replay_buffer_capacity\": replay_buffer_capacity,\n",
" \"epsilon_decay_duration\": num_train_episodes,\n",
" \"epsilon_start\": 0.06,\n",
" \"epsilon_end\": 0.001,\n",
" }\n",
" expl_list = []\n",
" with tf.Session() as sess:\n",
" # pylint: disable=g-complex-comprehension\n",
" agents = [\n",
" nfsp.NFSP(sess, idx, info_state_size, num_actions, hidden_layers_sizes,\n",
" reservoir_buffer_capacity, anticipatory_param,\n",
" **kwargs) for idx in range(num_players)\n",
" ]\n",
" expl_policies_avg = NFSPPolicies(env, agents, nfsp.MODE.average_policy)\n",
"\n",
" sess.run(tf.global_variables_initializer())\n",
" for ep in range(num_train_episodes):\n",
" if (ep + 1) % eval_every == 0:\n",
" losses = [agent.loss for agent in agents]\n",
" print(\"Losses: %s\" %losses)\n",
" expl = exploitability.exploitability(env.game, expl_policies_avg)\n",
" expl_list.append(expl)\n",
" print(\"[%s] Exploitability AVG %s\" %(ep + 1, expl))\n",
" print(\"_____________________________________________\")\n",
"\n",
" time_step = env.reset()\n",
" while not time_step.last():\n",
" player_id = time_step.observations[\"current_player\"]\n",
" agent_output = agents[player_id].step(time_step)\n",
" action_list = [agent_output.action]\n",
" time_step = env.step(action_list)\n",
"\n",
" # Episode is over, step all agents with final info state.\n",
" for agent in agents:\n",
" agent.step(time_step)\n",
" return expl_list\n",
" \n",
"def pt_main(game,\n",
" env_config,\n",
" num_train_episodes,\n",
" eval_every,\n",
" hidden_layers_sizes,\n",
" replay_buffer_capacity,\n",
" reservoir_buffer_capacity,\n",
" anticipatory_param):\n",
" env = rl_environment.Environment(game, **env_configs)\n",
" info_state_size = env.observation_spec()[\"info_state\"][0]\n",
" num_actions = env.action_spec()[\"num_actions\"]\n",
"\n",
" hidden_layers_sizes = [int(l) for l in hidden_layers_sizes]\n",
" kwargs = {\n",
" \"replay_buffer_capacity\": replay_buffer_capacity,\n",
" \"epsilon_decay_duration\": num_train_episodes,\n",
" \"epsilon_start\": 0.06,\n",
" \"epsilon_end\": 0.001,\n",
" }\n",
" expl_list = []\n",
" agents = [\n",
" nfsp_pt.NFSP(idx, info_state_size, num_actions, hidden_layers_sizes,\n",
" reservoir_buffer_capacity, anticipatory_param,\n",
" **kwargs) for idx in range(num_players)\n",
" ]\n",
" expl_policies_avg = NFSPPolicies(env, agents, nfsp_pt.MODE.average_policy) \n",
" for ep in range(num_train_episodes):\n",
" if (ep + 1) % eval_every == 0:\n",
" losses = [agent.loss.item() for agent in agents]\n",
" print(\"Losses: %s\" %losses)\n",
" expl = exploitability.exploitability(env.game, expl_policies_avg)\n",
" expl_list.append(expl)\n",
" print(\"[%s] Exploitability AVG %s\" %(ep + 1, expl))\n",
" print(\"_____________________________________________\") \n",
" time_step = env.reset()\n",
" while not time_step.last():\n",
" player_id = time_step.observations[\"current_player\"]\n",
" agent_output = agents[player_id].step(time_step)\n",
" action_list = [agent_output.action]\n",
" time_step = env.step(action_list) \n",
" # Episode is over, step all agents with final info state.\n",
" for agent in agents:\n",
" agent.step(time_step)\n",
" return expl_list"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"game = \"kuhn_poker\"\n",
"num_players = 2\n",
"env_configs = {\"players\": num_players}\n",
"num_train_episodes = int(3e6)\n",
"eval_every = 10000\n",
"hidden_layers_sizes = [128]\n",
"replay_buffer_capacity = int(2e5)\n",
"reservoir_buffer_capacity = int(2e6)\n",
"anticipatory_param = 0.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tf_kuhn_result = tf_main(game, \n",
" env_configs,\n",
" num_train_episodes,\n",
" eval_every,\n",
" hidden_layers_sizes,\n",
" replay_buffer_capacity,\n",
" reservoir_buffer_capacity,\n",
" anticipatory_param)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pt_kuhn_result = pt_main(game, \n",
" env_configs,\n",
" num_train_episodes,\n",
" eval_every,\n",
" hidden_layers_sizes,\n",
" replay_buffer_capacity,\n",
" reservoir_buffer_capacity,\n",
" anticipatory_param)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"x = [i*1000 for i in range(len(tf_kuhn_result))]\n",
"\n",
"plt.plot(x, tf_kuhn_result, label='tensorflow')\n",
"plt.plot(x, pt_kuhn_result, label='pytorch')\n",
"plt.title('Kuhn Poker')\n",
"plt.xlabel('Episodes')\n",
"plt.ylabel('Exploitability')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"game = \"leduc_poker\"\n",
"num_players = 2\n",
"env_configs = {\"players\": num_players}\n",
"num_train_episodes = int(3e6)\n",
"eval_every = 100000\n",
"hidden_layers_sizes = [128]\n",
"replay_buffer_capacity = int(2e5)\n",
"reservoir_buffer_capacity = int(2e6)\n",
"anticipatory_param = 0.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tf_leduc_result = tf_main(game, \n",
" env_configs,\n",
" num_train_episodes,\n",
" eval_every,\n",
" hidden_layers_sizes,\n",
" replay_buffer_capacity,\n",
" reservoir_buffer_capacity,\n",
" anticipatory_param)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pt_leduc_result = pt_main(game, \n",
" env_configs,\n",
" num_train_episodes,\n",
" eval_every,\n",
" hidden_layers_sizes,\n",
" replay_buffer_capacity,\n",
" reservoir_buffer_capacity,\n",
" anticipatory_param)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = [i * 10000 for i in range(len(tf_leduc_result))]\n",
"\n",
"plt.plot(x, tf_leduc_result, label='tensorflow')\n",
"plt.plot(x, pt_leduc_result, label='pytorch')\n",
"plt.title('Leduc Poker')\n",
"plt.xlabel('Episodes')\n",
"plt.ylabel('Exploitability')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
1 change: 1 addition & 0 deletions open_spiel/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ set(PYTHON_TESTS ${PYTHON_TESTS}
utils/spawn_test.py
pytorch/rcfr_pytorch_test.py
pytorch/dqn_pytorch_test.py
pytorch/nfsp_pytorch_test.py
)


Expand Down
46 changes: 41 additions & 5 deletions open_spiel/python/pytorch/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from __future__ import print_function

import collections
import math
import random
import numpy as np
from scipy.stats import truncnorm
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -87,6 +89,40 @@ def __iter__(self):
return iter(self._data)


class SonnetLinear(nn.Module):
"""A Sonnet linear module.
Always includes biases and only supports ReLU activations.
"""

def __init__(self, in_size, out_size, activate_relu=True):
"""Creates a Sonnet linear layer.
Args:
in_size: (int) number of inputs
out_size: (int) number of outputs
activate_relu: (bool) whether to include a ReLU activation layer
"""
super(SonnetLinear, self).__init__()
self._activate_relu = activate_relu
stddev = 1.0 / math.sqrt(in_size)
mean = 0
lower = (-2 * stddev - mean) / stddev
upper = (2 * stddev - mean) / stddev
# Weight initialization inspired by Sonnet's Linear layer,
# which cites https://arxiv.org/abs/1502.03167v3
# pytorch default: initialized from
# uniform(-sqrt(1/in_features), sqrt(1/in_features))
self._weight = nn.Parameter(torch.Tensor(
truncnorm.rvs(lower, upper, loc=mean, scale=stddev,
size=[out_size, in_size])))
self._bias = nn.Parameter(torch.zeros([out_size]))

def forward(self, tensor):
y = F.linear(tensor, self._weight, self._bias)
return F.relu(y) if self._activate_relu else y


class MLP(nn.Module):
"""A simple network built from nn.linear layers."""

Expand All @@ -108,14 +144,14 @@ def __init__(self,
self._layers = []
# Hidden layers
for size in hidden_sizes:
self._layers.append(nn.Linear(in_features=input_size, out_features=size))
self._layers.append(nn.ReLU())
self._layers.append(SonnetLinear(in_size=input_size, out_size=size))
input_size = size
# Output layer
self._layers.append(
nn.Linear(in_features=input_size, out_features=output_size))
if activate_final:
self._layers.append(nn.ReLU())
SonnetLinear(
in_size=input_size,
out_size=output_size,
activate_relu=activate_final))

self.model = nn.ModuleList(self._layers)

Expand Down
Loading

0 comments on commit ffe263b

Please sign in to comment.