-
Notifications
You must be signed in to change notification settings - Fork 0
/
buffers.py
201 lines (170 loc) · 6.48 KB
/
buffers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import jax
import jax.numpy as jnp
import flax.struct
from qdax.core.neuroevolution.buffers import buffer
import numpy as np
import logging
import time
from functools import partial
from collections.abc import Callable
from typing import Self, TYPE_CHECKING
from ..utils import RNGKey, jax_jit, onp_callback, jax_pure_callback
if TYPE_CHECKING:
from ..tasks import RLTask
_log = logging.getLogger(__name__)
# Host-side registries shared between JAX-traced code and the numpy callbacks
# below.  Buffers are addressed by integer index: the JAX side only carries
# the index (stored as float32 in `CPUReplayBuffer.data`, see `init`), while
# the actual transition arrays live here in host memory.
global_buffer_data: list[np.ndarray] = []
# Per-buffer transformation applied to incoming transitions before insertion;
# set to `task.onp_restore_transitions` in `CPUReplayBuffer.init`.
global_restore_fn: list[Callable[[np.ndarray], np.ndarray]] = []
# Cumulative wall-clock timers (seconds) for profiling the host callbacks.
global_time_restore: float = 0.0
global_time_insert: float = 0.0
global_time_sample: float = 0.0
@onp_callback
def _update_global_buffer_data(
    fake_data: np.ndarray,
    flattened_transitions: np.ndarray,
    roll: np.ndarray,
    new_position: np.ndarray,
):
    """Host-side callback: insert `flattened_transitions` into the global buffer.

    `fake_data` encodes the target buffer's index in `global_buffer_data`
    (every element holds the same value, see `CPUReplayBuffer.init`); it is
    returned unchanged so the JAX side keeps a data dependency on this
    callback.  `roll`/`new_position` follow the semantics of
    `_do_update_global_buffer_data`.
    """
    global global_buffer_data
    global global_time_restore
    global global_time_insert
    global_time_insert -= time.monotonic()
    global_idx = int(fake_data.flatten()[0])
    # A batch spanning the full buffer length replaces the storage wholesale.
    replace_all = flattened_transitions.shape[-2] == global_buffer_data[global_idx].shape[-2]
    if replace_all:
        # Drop the old array first so its memory can be reclaimed while the
        # restore below allocates the replacement.
        global_buffer_data[global_idx] = None  # type: ignore
    global_time_restore -= time.monotonic()
    flattened_transitions = global_restore_fn[global_idx](flattened_transitions)
    global_time_restore += time.monotonic()
    if replace_all:
        global_buffer_data[global_idx] = flattened_transitions
    else:
        # Partial insert: roll the buffer if the batch would wrap, then
        # overwrite one contiguous slice.
        global_buffer_data[global_idx] = _do_update_global_buffer_data(
            global_buffer_data[global_idx],
            flattened_transitions,
            roll.flatten()[0],
            new_position.flatten()[0],
        )
    global_time_insert += time.monotonic()
    return fake_data
def _do_update_global_buffer_data(
buffer_data: np.ndarray,
flattened_transitions: np.ndarray,
roll: np.ndarray,
new_position: np.ndarray,
):
if roll != 0:
_log.warning('Rolling...')
buffer_data = np.roll(buffer_data, roll, axis=-2)
buffer_data[
..., new_position:new_position + flattened_transitions.shape[-2], :
] = flattened_transitions
return buffer_data
@onp_callback
def _take_from_global_buffer(data: np.ndarray, idx: np.ndarray):
    """Host-side callback: gather transitions `idx` from the buffer named by `data`.

    `data` encodes the index into `global_buffer_data`; `idx` selects rows
    along the second-to-last (transition) axis, clipped to the valid range.
    """
    global global_time_sample
    global_time_sample -= time.monotonic()
    buffer_index = int(data.flatten()[0])
    # Batched take mapping (i, j), (k) -> (k, j) over any leading dimensions.
    batched_take = np.vectorize(
        np.take, excluded=('axis', 'mode'), signature='(i,j),(k)->(k,j)'
    )
    sampled = batched_take(
        global_buffer_data[buffer_index], idx, axis=-2, mode='clip'
    )
    global_time_sample += time.monotonic()
    return sampled
class CPUReplayBuffer(buffer.ReplayBuffer):
    """Replay buffer whose transition storage lives in host (CPU) memory.

    The `data` field does not hold the transitions themselves: every element
    is the index of this buffer's entry in the module-level
    `global_buffer_data` list.  All storage access (insert/sample) goes
    through host callbacks via `jax_pure_callback`.
    """

    # Flattened feature size of one transition; static field, so it is not a
    # traced pytree leaf.
    flatten_dim: int = flax.struct.field(pytree_node=False)

    @classmethod
    def init(  # pyright: ignore [reportIncompatibleMethodOverride]
        cls,
        buffer_size: int,
        transition: buffer.Transition,
        rand: jax.Array,
        task: 'RLTask',
    ) -> Self:
        """Allocate a host-memory buffer of `buffer_size` transitions.

        `rand` is only used for its shape: the host storage gets one leading
        dimension per axis of `rand` (presumably vmap/population axes — TODO
        confirm with callers).  `task.onp_restore_transitions` is registered
        as the restore function applied to batches on insertion.
        """
        data = jnp.zeros((1, 1), dtype=jnp.float32)

        @onp_callback
        def onp_fn(data: np.ndarray, rand: np.ndarray):
            global global_buffer_data
            # Prepend one axis per dimension of `rand` so the returned index
            # array matches the broadcast shape the JAX side expects.
            for size in reversed(rand.shape):
                data = np.repeat(np.expand_dims(data, axis=0), repeats=size, axis=0)
            shape = data.shape
            global_idx = len(global_buffer_data)
            _log.info(f'init.onp_fn: {global_idx}')
            # Single host array covering all leading dims of this buffer.
            data = np.zeros(
                (*data.shape[:-2], buffer_size, transition.flatten_dim), dtype=np.float32
            )
            global_buffer_data.append(data)
            global_restore_fn.append(task.onp_restore_transitions)
            # Visible result is just the registry index, broadcast to `shape`
            # (float32 to match the placeholder's dtype).
            return np.full(shape, global_idx, dtype=np.float32)

        data = jax_pure_callback(onp_fn, data, data, rand, vectorized=True)
        current_size = jnp.array(0, dtype=int)
        current_position = jnp.array(0, dtype=int)
        return cls(
            data=data,
            current_size=current_size,
            current_position=current_position,
            buffer_size=buffer_size,
            flatten_dim=transition.flatten_dim,
            transition=transition,
        )

    @partial(jax_jit, static_argnames=('sample_size',))
    def sample(  # pyright: ignore [reportIncompatibleVariableOverride]
        self,
        random_key: RNGKey,
        sample_size: int,
    ) -> tuple[buffer.Transition, RNGKey]:
        """Uniformly sample `sample_size` transitions from the host buffer.

        Returns the sampled transitions and the consumed/split RNG key.
        """
        random_key, subkey = jax.random.split(random_key)
        idx = jax.random.randint(
            subkey,
            shape=(sample_size,),
            minval=0,
            maxval=self.current_size,
        )
        # The gather runs on the host; only the sampled rows travel back to
        # the device as a (sample_size, flatten_dim) float32 array.
        samples = jax_pure_callback(
            _take_from_global_buffer,
            jax.ShapeDtypeStruct((sample_size, self.flatten_dim), dtype=jnp.float32),
            self.data,
            idx,
            vectorized=True,
        )
        assert isinstance(samples, jax.Array)
        transitions = self.transition.__class__.from_flatten(samples, self.transition)
        return transitions, random_key

    @jax_jit
    def insert(  # pyright: ignore [reportIncompatibleVariableOverride]
        self, transitions: buffer.Transition
    ) -> Self:
        """Insert a batch of transitions, rolling over the oldest data on wrap.

        Raises ValueError (at trace time, since the shape is static) if the
        batch is larger than the buffer itself.
        """
        flattened_transitions = transitions.flatten()
        flattened_transitions = flattened_transitions.reshape(
            (-1, flattened_transitions.shape[-1])
        )
        num_transitions = flattened_transitions.shape[0]
        max_replay_size = self.buffer_size
        # Make sure update is not larger than the maximum replay size.
        if num_transitions > max_replay_size:
            raise ValueError(
                'Trying to insert a batch of samples larger than the maximum replay '
                f'size. num_samples: {num_transitions}, '
                f'max replay size {max_replay_size}'
            )
        # get current position
        position = self.current_position
        # check if there is an overlap
        # (roll <= 0: how far the batch would run past the end of the buffer)
        roll = jnp.minimum(0, max_replay_size - position - num_transitions)
        # update the position accordingly
        new_position = position + roll
        # replace old data by the new one (host-side mutation; `new_data` is
        # just the index array threaded through for the data dependency)
        new_data = jax_pure_callback(
            _update_global_buffer_data,
            self.data,
            self.data, flattened_transitions, roll, new_position,
            vectorized=True
        )
        # update the position and the size
        new_position = (new_position + num_transitions) % max_replay_size
        new_size = jnp.minimum(self.current_size + num_transitions, max_replay_size)
        # update the replay buffer
        replay_buffer = self.replace(
            current_position=new_position,
            current_size=new_size,
            data=new_data,
        )
        return replay_buffer