-
Notifications
You must be signed in to change notification settings - Fork 2
/
model.py
156 lines (139 loc) · 5.2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
import torch.nn as nn
from torch.distributions import Categorical, Beta, MixtureSameFamily
class NeuralNet(nn.Module):
    """Feed-forward network: a torso followed by a single linear output head.

    Args:
        dim: Input size of the first layer of the default torso.
        hidden_dim: Width of the default torso and input size of the output
            head. A custom ``torso`` must therefore produce ``hidden_dim``
            features.
        n_hidden: Number of extra ``(Linear(hidden_dim, hidden_dim), ELU)``
            blocks appended to the default torso.
        torso: Optional module used in place of the default torso.
        output_dim: Number of outputs of the linear head.
    """

    def __init__(self, dim=2, hidden_dim=64, n_hidden=2, torso=None, output_dim=3):
        super().__init__()
        self.dim = dim
        self.n_hidden = n_hidden
        self.output_dim = output_dim
        # The default torso is only constructed when the caller gave none.
        self.torso = (
            self._default_torso(dim, hidden_dim, n_hidden) if torso is None else torso
        )
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    @staticmethod
    def _default_torso(dim, hidden_dim, n_hidden):
        """Build ``Linear(dim, hidden) -> ELU`` plus ``n_hidden`` nested hidden blocks."""
        # Layers are created front-to-back; creation order determines the
        # random initialisation under a fixed seed. Each hidden block stays a
        # nested Sequential so parameter names look like ``torso.2.0.weight``.
        modules = [nn.Linear(dim, hidden_dim), nn.ELU()]
        for _ in range(n_hidden):
            modules.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ELU()))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Return ``output_layer(torso(x))``."""
        features = self.torso(x)
        return self.output_layer(features)
class CirclePF(NeuralNet):
    """Forward policy: a Beta mixture over the next step plus an exit probability.

    For states other than ``s0`` the network emits ``1 + 3 * n_components``
    values per state: one exit pre-activation, ``n_components`` mixture
    logits, and ``n_components`` pre-activations each for the Beta ``alpha``
    and ``beta`` concentrations. ``s0`` (the all-zero state) is handled by
    free, state-independent parameters stored in ``self.PFs0``.

    Concentrations are mapped with ``beta_max * sigmoid(z) + beta_min``, so
    they lie between ``beta_min`` and ``beta_min + beta_max``.

    Args:
        hidden_dim: Width of the default torso; a custom ``torso`` must output
            this many features.
        n_hidden: Number of hidden blocks in the default torso.
        n_components_s0: Number of mixture components used at ``s0``.
        n_components: Number of mixture components used for other states.
        beta_min: Offset added to the squashed concentrations.
        beta_max: Scale applied to the squashed concentrations.
        torso: Optional torso module (e.g. shared with a backward policy).
            When ``None`` a default torso with input size 2 is built.
    """

    def __init__(
        self,
        hidden_dim=64,
        n_hidden=2,
        n_components_s0=1,
        n_components=1,
        beta_min=0.1,
        beta_max=2.0,
        torso=None,
    ):
        output_dim = 1 + 3 * n_components
        super().__init__(
            dim=2,
            hidden_dim=hidden_dim,
            n_hidden=n_hidden,
            torso=torso,
            output_dim=output_dim,
        )
        # Free parameters for PF(. | s0). The keys are part of the state_dict
        # (e.g. "PFs0.log_alpha_r"), so renaming them breaks checkpoints.
        self.PFs0 = nn.ParameterDict(
            {
                "log_alpha_r": nn.Parameter(torch.zeros(n_components_s0)),
                "log_alpha_theta": nn.Parameter(torch.zeros(n_components_s0)),
                "log_beta_r": nn.Parameter(torch.zeros(n_components_s0)),
                "log_beta_theta": nn.Parameter(torch.zeros(n_components_s0)),
                "logits": nn.Parameter(torch.zeros(n_components_s0)),
            }
        )
        self.n_components = n_components
        self.n_components_s0 = n_components_s0
        self.beta_min = beta_min
        self.beta_max = beta_max

    def _squash(self, z):
        """Map an unconstrained tensor to a Beta concentration via ``beta_max * sigmoid(z) + beta_min``."""
        return self.beta_max * torch.sigmoid(z) + self.beta_min

    def forward(self, x):
        """Run the network and split its output into policy parameters.

        Args:
            x: Batch of states. The default torso expects a trailing dimension
                of size 2; any leading dimensions are treated as batch dims.

        Returns:
            ``(exit_proba, mixture_logits, alpha, beta)``. ``exit_proba`` is the
            sigmoid of the first output and has shape ``x.shape[:-1]``. The
            other three have shape ``x.shape[:-1] + (n_components,)``, and
            ``alpha``/``beta`` are already squashed concentrations.
        """
        out = super().forward(x)
        k = self.n_components
        pre_sigmoid_exit = out[..., 0]
        mixture_logits = out[..., 1 : 1 + k]
        log_alpha = out[..., 1 + k : 1 + 2 * k]
        log_beta = out[..., 1 + 2 * k : 1 + 3 * k]
        return (
            torch.sigmoid(pre_sigmoid_exit),
            mixture_logits,
            self._squash(log_alpha),
            self._squash(log_beta),
        )

    def to_dist(self, x):
        """Build the forward-policy distribution(s) for a batch of states.

        Args:
            x: Batch of states. If the first state is all zeros, the batch is
                treated as a batch of initial states ``s0``.

        Returns:
            For an ``s0`` batch: ``(dist_r, dist_theta)``, two Beta mixtures
            built from ``self.PFs0``. They do not depend on ``x`` (empty batch
            shape) and share mixing logits, but are sampled independently.
            ``r``/``theta`` are presumably radius and angle — TODO confirm
            against the environment.
            Otherwise: ``(exit_proba, dist)``, where ``exit_proba`` comes from
            ``forward`` and ``dist`` is a Beta mixture with batch shape
            ``x.shape[:-1]``.

        Raises:
            AssertionError: (when assertions are enabled) if the first state
                is all zeros but some other state is not. A mixed batch whose
                first state is not ``s0`` is not detected.
        """
        if torch.all(x[0] == 0.0):
            assert torch.all(
                x == 0.0
            ), "batch starts with s0, so every state in it must be s0"
            params = self.PFs0
            logits = params["logits"]
            dist_r = MixtureSameFamily(
                Categorical(logits=logits),
                Beta(
                    self._squash(params["log_alpha_r"]),
                    self._squash(params["log_beta_r"]),
                ),
            )
            dist_theta = MixtureSameFamily(
                Categorical(logits=logits),
                Beta(
                    self._squash(params["log_alpha_theta"]),
                    self._squash(params["log_beta_theta"]),
                ),
            )
            return dist_r, dist_theta

        # Any other batch: parameters come from the neural network.
        exit_proba, mixture_logits, alpha, beta = self.forward(x)
        dist = MixtureSameFamily(
            Categorical(logits=mixture_logits),
            Beta(alpha, beta),
        )
        return exit_proba, dist
class CirclePB(NeuralNet):
    """Backward policy: a Beta mixture predicted from the current state.

    The network emits ``3 * n_components`` values per state: mixture logits,
    then pre-activations for the Beta ``alpha`` and ``beta`` concentrations.
    The concentrations are mapped with ``beta_max * sigmoid(z) + beta_min``.
    With ``uniform=True`` the network is bypassed and ``to_dist`` returns
    ``Beta(1, 1)``, the uniform distribution on [0, 1], for every state.

    Args:
        hidden_dim: Width of the default torso; a custom ``torso`` must output
            this many features.
        n_hidden: Number of hidden blocks in the default torso.
        torso: Optional torso module (e.g. shared with a forward policy).
            When ``None`` a default torso with input size 2 is built.
        uniform: If True, ``to_dist`` ignores the network.
        n_components: Number of mixture components.
        beta_min: Offset added to the squashed concentrations.
        beta_max: Scale applied to the squashed concentrations.
    """

    def __init__(
        self,
        hidden_dim=64,
        n_hidden=2,
        torso=None,
        uniform=False,
        n_components=1,
        beta_min=0.1,
        beta_max=2.0,
    ):
        output_dim = 3 * n_components
        # Hand the torso to NeuralNet so a default torso is only built when
        # none is supplied, rather than built and then overwritten.
        super().__init__(
            dim=2,
            hidden_dim=hidden_dim,
            n_hidden=n_hidden,
            torso=torso,
            output_dim=output_dim,
        )
        self.uniform = uniform
        self.n_components = n_components
        self.beta_min = beta_min
        self.beta_max = beta_max

    def forward(self, x):
        """Run the network and split its output into mixture parameters.

        Args:
            x: Batch of states, typically of shape ``(batch_size, 2)``. Any
                leading dimensions are treated as batch dims.

        Returns:
            ``(mixture_logits, alpha, beta)``, each of shape
            ``x.shape[:-1] + (n_components,)``. ``alpha``/``beta`` are already
            squashed concentrations.
        """
        out = super().forward(x)
        k = self.n_components
        # Ellipsis indexing slices the last (feature) axis whatever the rank.
        mixture_logits = out[..., 0:k]
        log_alpha = out[..., k : 2 * k]
        log_beta = out[..., 2 * k : 3 * k]
        return (
            mixture_logits,
            self.beta_max * torch.sigmoid(log_alpha) + self.beta_min,
            self.beta_max * torch.sigmoid(log_beta) + self.beta_min,
        )

    def to_dist(self, x):
        """Return the backward-policy distribution for a batch of states.

        Args:
            x: Batch of states, typically of shape ``(batch_size, 2)``.

        Returns:
            ``Beta(1, 1)`` with batch shape ``x.shape[:-1]`` if ``uniform`` is
            set; otherwise a Beta mixture with the same batch shape.
        """
        if self.uniform:
            # Allocate on x's device so sample()/log_prob() work for GPU batches.
            ones = torch.ones(x.shape[:-1], device=x.device)
            return Beta(ones, ones)
        mixture_logits, alpha, beta = self.forward(x)
        return MixtureSameFamily(
            Categorical(logits=mixture_logits),
            Beta(alpha, beta),
        )