-
Notifications
You must be signed in to change notification settings - Fork 34
/
Agent.py
97 lines (78 loc) · 3.73 KB
/
Agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from copy import deepcopy
from typing import List
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Adam
class Agent:
    """Agent that can interact with environment from pettingzoo.

    Holds an actor and a critic network, a frozen-by-convention target copy
    of each, and one Adam optimizer per trainable network.
    """

    def __init__(self, obs_dim: int, act_dim: int, global_obs_dim: int,
                 actor_lr: float, critic_lr: float):
        """Build the actor/critic networks, their optimizers and targets.

        Args:
            obs_dim: input width of the actor network.
            act_dim: output width of the actor network (number of logits).
            global_obs_dim: input width of the critic network.
            actor_lr: learning rate of the actor's Adam optimizer.
            critic_lr: learning rate of the critic's Adam optimizer.
        """
        self.actor = MLPNetwork(obs_dim, act_dim)

        # The critic consumes every agent's observation and action at once:
        # with 3 agents its input is (obs1, obs2, obs3, act1, act2, act3),
        # concatenated along dim 1 (see critic_value). global_obs_dim must
        # therefore equal the total width of that concatenation.
        self.critic = MLPNetwork(global_obs_dim, 1)

        self.actor_optimizer = Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=critic_lr)

        # Target networks start as exact copies of the online networks.
        self.target_actor = deepcopy(self.actor)
        self.target_critic = deepcopy(self.critic)

    @staticmethod
    def gumbel_softmax(logits, tau=1.0, eps=1e-20):
        """Return a soft Gumbel-softmax sample over the last dimension.

        Hand-written alternative to ``torch.nn.functional.gumbel_softmax``.

        Args:
            logits: unnormalised scores; left unmodified by this call.
            tau: softmax temperature the perturbed logits are divided by.
            eps: small constant keeping both logarithms away from log(0).

        Returns:
            Tensor of the same shape as ``logits`` whose last dimension
            sums to 1.
        """
        uniform = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps)
        # Add the noise out of place: an in-place `logits += ...` would
        # overwrite the caller's tensor and fails outright on a leaf tensor
        # that requires grad.
        perturbed = logits + gumbel_noise
        return F.softmax(perturbed / tau, dim=-1)

    def action(self, obs, model_out: bool = False):
        """Sample an action from the actor for the given observation(s).

        Called in two situations:
          a) when interacting with the environment;
          b) when updating the actor, where ``obs`` is a batch sampled from
             the replay buffer with size ``[batch_size, state_dim]``.

        Args:
            obs: observation tensor fed to the actor.
            model_out: when True, also return the raw actor logits.

        Returns:
            The sampled action, or ``(action, logits)`` if ``model_out``.
        """
        logits = self.actor(obs)  # torch.Size([batch_size, action_size])
        # hard=True yields a one-hot sample in the forward pass while the
        # gradient flows through the soft sample (straight-through).
        action = F.gumbel_softmax(logits, hard=True)
        if model_out:
            return action, logits
        return action

    def target_action(self, obs):
        """Sample an action from the target actor, detached from autograd.

        Used when computing the target critic value in MADDPG: the target
        actor picks the next action for next states sampled from the replay
        buffer with size ``[batch_size, state_dim]``.

        Args:
            obs: observation tensor fed to the target actor.

        Returns:
            The sampled one-hot action with no gradient history. Note that
            ``squeeze(0)`` drops the leading dimension when its size is 1.
        """
        logits = self.target_actor(obs)  # torch.Size([batch_size, action_size])
        action = F.gumbel_softmax(logits, hard=True)
        return action.squeeze(0).detach()

    def critic_value(self, state_list: List[Tensor], act_list: List[Tensor]):
        """Evaluate the critic on concatenated states and actions.

        Args:
            state_list: one tensor per agent, concatenated along dim 1.
            act_list: one tensor per agent, appended after the states.

        Returns:
            Critic output with its size-1 dim 1 squeezed away.
        """
        x = torch.cat(state_list + act_list, 1)
        return self.critic(x).squeeze(1)  # tensor with a given length

    def target_critic_value(self, state_list: List[Tensor], act_list: List[Tensor]):
        """Evaluate the target critic on concatenated states and actions.

        Same input layout and output shape as ``critic_value``.
        """
        x = torch.cat(state_list + act_list, 1)
        return self.target_critic(x).squeeze(1)  # tensor with a given length

    def update_actor(self, loss):
        """Run one optimizer step on the actor for the given loss."""
        self.actor_optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm of the actor parameters to 0.5.
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optimizer.step()

    def update_critic(self, loss):
        """Run one optimizer step on the critic for the given loss."""
        self.critic_optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm of the critic parameters to 0.5.
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optimizer.step()
class MLPNetwork(nn.Module):
    """Fully connected network: three linear layers with an activation
    between consecutive layers."""

    def __init__(self, in_dim, out_dim, hidden_dim=64, non_linear=nn.ReLU()):
        """Assemble the layer stack and initialise its linear layers.

        Args:
            in_dim: width of the input features.
            out_dim: width of the output.
            hidden_dim: width of both hidden layers.
            non_linear: activation module placed after the first and the
                second linear layer (the same instance is used twice).
        """
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            non_linear,
            nn.Linear(hidden_dim, hidden_dim),
            non_linear,
            nn.Linear(hidden_dim, out_dim),
        ]
        stack = nn.Sequential(*layers)
        # Visit every sub-module and re-initialise the linear ones.
        stack.apply(self.init)
        self.net = stack

    @staticmethod
    def init(m):
        """init parameter of the module"""
        # Only linear layers carry parameters that are re-initialised here.
        if not isinstance(m, nn.Linear):
            return
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(m.weight, gain=relu_gain)
        m.bias.data.fill_(0.01)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        out = self.net(x)
        return out