#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# =====================================
# @Time    : 2020/6/10
# @Author  : Yang Guan (Tsinghua Univ.)
# @FileName: buffer.py
# =====================================
import logging
import random

import numpy as np

from utils.segment_tree import SumSegmentTree, MinSegmentTree

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class ReplayBuffer(object):
    def __init__(self, args, buffer_id):
        """Create a replay buffer with uniform sampling.

        Parameters
        ----------
        args:
            Parsed arguments; must provide max_buffer_size (the maximum number
            of transitions to store; when the buffer overflows, the oldest
            transitions are overwritten), replay_starts, replay_batch_size and
            buffer_log_interval.
        buffer_id: int
            Identifier of this buffer instance; buffer 1 periodically logs its
            statistics in replay().
        """
        self.args = args
        self.buffer_id = buffer_id
        self._storage = []
        self._maxsize = self.args.max_buffer_size
        self._next_idx = 0
        self.replay_starts = self.args.replay_starts
        self.replay_batch_size = self.args.replay_batch_size
        self.stats = {}
        self.replay_times = 0
        logger.info('Buffer initialized')

    def get_stats(self):
        self.stats.update(dict(storage=len(self._storage)))
        return self.stats

    def __len__(self):
        return len(self._storage)

    def add(self, obs_t, action, reward, obs_tp1, done, ref_index, weight):
        # `weight` is accepted for interface compatibility but is not used by
        # this uniform buffer.
        data = (obs_t, action, reward, obs_tp1, done, ref_index)
        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes):
        obses_t, actions, rewards, obses_tp1, dones, ref_indexs = [], [], [], [], [], []
        for i in idxes:
            data = self._storage[i]
            obs_t, action, reward, obs_tp1, done, ref_index = data
            obses_t.append(np.array(obs_t, copy=False))
            actions.append(np.array(action, copy=False))
            rewards.append(reward)
            obses_tp1.append(np.array(obs_tp1, copy=False))
            dones.append(done)
            ref_indexs.append(ref_index)
        return np.array(obses_t), np.array(actions), np.array(rewards), \
               np.array(obses_tp1), np.array(dones), np.array(ref_indexs)

    def sample_idxes(self, batch_size):
        return np.array([random.randint(0, len(self._storage) - 1) for _ in range(batch_size)],
                        dtype=np.int32)

    def sample_with_idxes(self, idxes):
        return list(self._encode_sample(idxes)) + [idxes]

    def sample(self, batch_size):
        idxes = self.sample_idxes(batch_size)
        return self.sample_with_idxes(idxes)

    def add_batch(self, batch):
        for trans in batch:
            # The trailing 0 fills the unused `weight` argument of add().
            self.add(*trans, 0)

    def replay(self):
        if len(self._storage) < self.replay_starts:
            return None
        if self.buffer_id == 1 and self.replay_times % self.args.buffer_log_interval == 0:
            logger.info('Buffer info: {}'.format(self.get_stats()))
        self.replay_times += 1
        return self.sample(self.replay_batch_size)
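

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes an
# argparse-style namespace carrying the four attributes the buffer reads
# (max_buffer_size, replay_starts, replay_batch_size, buffer_log_interval);
# the observation/action shapes and ref_index values below are illustrative
# placeholders, not values taken from the original project.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(max_buffer_size=1000,
                           replay_starts=32,
                           replay_batch_size=8,
                           buffer_log_interval=10)
    buffer = ReplayBuffer(args, buffer_id=1)

    # Fill with dummy transitions: (obs_t, action, reward, obs_tp1, done, ref_index, weight).
    for step in range(100):
        obs_t = np.random.randn(4)
        action = np.random.randn(2)
        obs_tp1 = np.random.randn(4)
        buffer.add(obs_t, action, reward=0.0, obs_tp1=obs_tp1, done=False,
                   ref_index=step % 5, weight=0.0)

    # replay() returns None until replay_starts transitions are stored; after
    # that it returns [obses_t, actions, rewards, obses_tp1, dones, ref_indexs, idxes].
    batch = buffer.replay()
    if batch is not None:
        obses_t, actions, rewards, obses_tp1, dones, ref_indexs, idxes = batch
        print(obses_t.shape, actions.shape, rewards.shape, idxes.shape)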