-
Notifications
You must be signed in to change notification settings - Fork 0
/
baseline.py
60 lines (50 loc) · 2.25 KB
/
baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
class LinearFeatureBaseline(nn.Module):
    """Linear baseline based on handcrafted features, as described in [1]
    (Supplementary Material 2).

    Returns are predicted as a linear function (no bias) of the features
    built by ``_feature``: masked observations, their squares, a scaled step
    counter ``t`` with ``t ** 2`` and ``t ** 3``, and the mask itself. The
    weights are set in closed form by ``fit`` (regularized least squares),
    not by gradient descent.

    [1] Yan Duan, Xi Chen, Rein Houthooft, John Schulman, Pieter Abbeel,
        "Benchmarking Deep Reinforcement Learning for Continuous Control", 2016
        (https://arxiv.org/abs/1604.06778)

    Args:
        input_size: Size of the last dimension of ``episodes.observations``.
        reg_coeff: Initial ridge coefficient added to the diagonal of the
            normal equations in ``fit``.
    """

    def __init__(self, input_size, reg_coeff=1e-5):
        super().__init__()
        self.input_size = input_size
        self._reg_coeff = reg_coeff
        # No bias: the mask column of the features plays the role of the
        # constant term on valid steps.
        self.linear = nn.Linear(self.feature_size, 1, bias=False)
        self.linear.weight.data.zero_()

    @property
    def feature_size(self):
        """Number of features: observations and their squares
        (``2 * input_size``) plus ``t``, ``t ** 2``, ``t ** 3`` and the mask
        column (4)."""
        return 2 * self.input_size + 4

    def _feature(self, episodes):
        """Build the feature tensor for ``episodes``.

        Assumes ``episodes.mask`` is ``(seq_len, batch)`` with 0/1 entries and
        ``episodes.observations`` is ``(seq_len, batch, input_size)`` (the
        ``unsqueeze(2)`` / ``cat(dim=2)`` below rely on this layout; TODO
        confirm against the episode container). Features are concatenated on
        the last dimension, giving ``feature_size`` columns; steps with mask 0
        get all-zero features.
        """
        ones = episodes.mask.unsqueeze(2)
        observations = episodes.observations * ones
        # Running count of valid steps along the time axis (dim 0), zeroed on
        # masked-out steps, then scaled down to keep the polynomial terms small.
        cum_sum = torch.cumsum(ones, dim=0) * ones
        al = cum_sum / 100.0
        return torch.cat([observations, observations ** 2,
                          al, al ** 2, al ** 3, ones], dim=2)

    def fit(self, episodes):
        """Fit the weights to ``episodes.returns`` by ridge regression.

        Solves ``(X^T X + reg * I) w = X^T y``, where ``X`` is the flattened
        feature matrix and ``y`` the flattened returns, and stores ``w^T`` as
        the weight of ``self.linear``. If a solve raises ``RuntimeError`` or
        yields non-finite coefficients, ``reg`` is multiplied by 10 and the
        solve is retried, for at most 5 attempts.

        Raises:
            RuntimeError: If none of the attempts produced finite coefficients.
        """
        # (sequence_length * batch_size) x feature_size
        featmat = self._feature(episodes).reshape(-1, self.feature_size)
        # (sequence_length * batch_size) x 1
        returns = episodes.returns.reshape(-1, 1)

        # The normal-equation terms do not depend on the regularization, so
        # compute them once instead of on every retry.
        xtx = torch.matmul(featmat.t(), featmat)
        xty = torch.matmul(featmat.t(), returns)
        eye = torch.eye(self.feature_size, dtype=xtx.dtype, device=xtx.device)

        reg_coeff = self._reg_coeff
        for attempt in range(5):
            if attempt:
                # Grow the regularization by one order of magnitude per retry.
                reg_coeff *= 10
            try:
                coeffs = torch.linalg.lstsq(xtx + reg_coeff * eye, xty).solution
            except RuntimeError:
                continue
            # Some LAPACK drivers (e.g. the default CPU one) do not raise on
            # ill-conditioned systems, so also reject non-finite solutions.
            if torch.isfinite(coeffs).all():
                break
        else:
            # ``reg_coeff`` is the largest value that was actually tried.
            raise RuntimeError('Unable to solve the normal equations in '
                '`LinearFeatureBaseline`. The matrix X^T*X (with X the design '
                'matrix) is not full-rank, regardless of the regularization '
                '(maximum regularization: {0}).'.format(reg_coeff))
        self.linear.weight.data = coeffs.detach().t()

    def forward(self, episodes):
        """Return the baseline values for ``episodes``: the linear map of the
        features, with a trailing dimension of size 1 (one value per step)."""
        features = self._feature(episodes)
        return self.linear(features)