policies.py
#!/usr/bin/env python3

import math

import cherry as ch
import torch
from torch import nn
from torch.distributions import Normal, Categorical

EPSILON = 1e-6


def linear_init(module):
    # Xavier-uniform weights and zero biases for nn.Linear layers;
    # other module types are returned unchanged.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        module.bias.data.zero_()
    return module
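

# Illustrative sketch (not part of the upstream file): linear_init is called on
# each nn.Linear as it is constructed below, but it can also be applied to an
# existing model through nn.Module.apply, which invokes it on every sub-module:
#
#     net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
#     net.apply(linear_init)  # Xavier-uniform weights, zero biases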


class CaviaDiagNormalPolicy(nn.Module):
    # Diagonal Gaussian policy with CAVIA-style context parameters that are
    # concatenated to the observation and adapted per task.

    def __init__(self, input_size, output_size, hiddens=None, activation='relu',
                 num_context_params=2, device='cpu'):
        super(CaviaDiagNormalPolicy, self).__init__()
        self.device = device
        if hiddens is None:
            hiddens = [100, 100]
        if activation == 'relu':
            activation = nn.ReLU
        elif activation == 'tanh':
            activation = nn.Tanh
        layers = [linear_init(nn.Linear(input_size + num_context_params, hiddens[0])), activation()]
        for i, o in zip(hiddens[:-1], hiddens[1:]):
            layers.append(linear_init(nn.Linear(i, o)))
            layers.append(activation())
        layers.append(linear_init(nn.Linear(hiddens[-1], output_size)))

        self.num_context_params = num_context_params
        # Context parameters are deliberately plain tensors (not nn.Parameters):
        # they are adapted in the inner loop and reset between tasks. Creating
        # them directly on the target device keeps them leaf tensors.
        self.context_params = torch.zeros(self.num_context_params, device=self.device, requires_grad=True)

        self.mean = nn.Sequential(*layers).to(self.device)
        # Learned per-dimension log-standard-deviation, initialised to log(1) = 0.
        self.sigma = nn.Parameter(torch.zeros(output_size, device=self.device))
        self.sigma.data.fill_(math.log(1))

    def density(self, state):
        state = state.to(self.device, non_blocking=True)
        # concatenate context parameters to the input
        state = torch.cat(
            (state, self.context_params.expand(state.shape[:-1] + self.context_params.shape)),
            dim=len(state.shape) - 1,
        )
        loc = self.mean(state)
        scale = torch.exp(torch.clamp(self.sigma, min=math.log(EPSILON)))
        return Normal(loc=loc, scale=scale)

    def log_prob(self, state, action):
        density = self.density(state)
        return density.log_prob(action).mean(dim=1, keepdim=True)

    def forward(self, state):
        density = self.density(state)
        action = density.sample()
        return action

    def reset_context(self):
        # Re-initialise the per-task context to zeros. Reassignment is used
        # instead of in-place zeroing, because in-place ops on a leaf tensor
        # that requires grad raise a RuntimeError.
        self.context_params = torch.zeros(self.num_context_params, device=self.device, requires_grad=True)
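

# Illustrative usage sketch (not from this repo): one CAVIA inner-loop step on
# the context parameters. The observation/action sizes, batch size, and the
# 0.1 inner-loop learning rate are assumptions for the example only.
#
#     policy = CaviaDiagNormalPolicy(input_size=4, output_size=2)
#     states = torch.randn(8, 4)
#     actions = policy(states)
#     loss = -policy.log_prob(states, actions).mean()
#     grad = torch.autograd.grad(loss, policy.context_params)[0]
#     policy.context_params = policy.context_params - 0.1 * grad  # adapt context
#     policy.reset_context()                                      # zero it again for the next task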


class DiagNormalPolicy(nn.Module):
    # Diagonal Gaussian policy: an MLP outputs the mean, and a learned
    # per-dimension log-standard-deviation is shared across states.

    def __init__(self, input_size, output_size, hiddens=None, activation='relu', device='cpu'):
        super(DiagNormalPolicy, self).__init__()
        self.device = device
        if hiddens is None:
            hiddens = [100, 100]
        if activation == 'relu':
            activation = nn.ReLU
        elif activation == 'tanh':
            activation = nn.Tanh
        layers = [linear_init(nn.Linear(input_size, hiddens[0])), activation()]
        for i, o in zip(hiddens[:-1], hiddens[1:]):
            layers.append(linear_init(nn.Linear(i, o)))
            layers.append(activation())
        layers.append(linear_init(nn.Linear(hiddens[-1], output_size)))
        self.mean = nn.Sequential(*layers)
        # Learned log-standard-deviation, initialised to log(1) = 0.
        self.sigma = nn.Parameter(torch.Tensor(output_size))
        self.sigma.data.fill_(math.log(1))

    def density(self, state):
        state = state.to(self.device, non_blocking=True)
        loc = self.mean(state)
        scale = torch.exp(torch.clamp(self.sigma, min=math.log(EPSILON)))
        return Normal(loc=loc, scale=scale)

    def log_prob(self, state, action):
        density = self.density(state)
        return density.log_prob(action).mean(dim=1, keepdim=True)

    def forward(self, state):
        density = self.density(state)
        action = density.sample()
        return action
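

# Illustrative usage sketch (sizes and returns are placeholder assumptions):
# sampling actions and forming a REINFORCE-style loss. log_prob returns the
# mean over action dimensions of the per-dimension log-densities, per state.
#
#     policy = DiagNormalPolicy(input_size=4, output_size=2)
#     states = torch.randn(8, 4)
#     actions = policy(states)                      # (8, 2) sampled actions
#     log_probs = policy.log_prob(states, actions)  # (8, 1)
#     returns = torch.randn(8, 1)                   # placeholder returns
#     loss = -(log_probs * returns).mean()
#     loss.backward()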


class CategoricalPolicy(nn.Module):
    # Categorical policy for discrete state and action spaces: the integer
    # state index is one-hot encoded before being fed to the MLP.

    def __init__(self, input_size, output_size, hiddens=None):
        super(CategoricalPolicy, self).__init__()
        if hiddens is None:
            hiddens = [100, 100]
        layers = [linear_init(nn.Linear(input_size, hiddens[0])), nn.ReLU()]
        for i, o in zip(hiddens[:-1], hiddens[1:]):
            layers.append(linear_init(nn.Linear(i, o)))
            layers.append(nn.ReLU())
        layers.append(linear_init(nn.Linear(hiddens[-1], output_size)))
        self.mean = nn.Sequential(*layers)
        self.input_size = input_size

    def forward(self, state):
        state = ch.onehot(state, dim=self.input_size)
        loc = self.mean(state)
        density = Categorical(logits=loc)
        action = density.sample()
        log_prob = density.log_prob(action).mean().view(-1, 1).detach()
        return action, {'density': density, 'log_prob': log_prob}
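

# Illustrative usage sketch, assuming ch.onehot accepts a tensor of integer
# state indices and expands it to one-hot vectors of length input_size (as in
# cherry's discrete-state examples); the sizes below are arbitrary.
#
#     policy = CategoricalPolicy(input_size=16, output_size=4)
#     states = torch.tensor([3, 7])     # two discrete state indices
#     actions, info = policy(states)    # sampled actions plus density/log_prob info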