The current version of DCRL_emitter duplicates the sampling of transitions from the replay buffer. The function "emit_pg" calls a vmapped version of "mutation_function_pg", which samples transitions from the replay buffer. Since the sampling is performed with the same random key (emitter_state.key), every offspring receives an identical batch, and the transitions for policy-gradient training are duplicated across offspring. Hence, we generate dcrl_batch_size x num_pg_training_steps x batch_size transitions at each step.
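A minimal sketch of the mechanism, with hypothetical names (a toy buffer and sampler, not QDax's actual code): vmapping a sampler over a single broadcast key makes every offspring draw the exact same batch, whereas splitting the key would not.

```python
import jax
import jax.numpy as jnp

buffer = jnp.arange(1000)  # stand-in for the replay buffer

def sample(key, _offspring_id):
    # Each "offspring" samples a batch of 4 transition indices.
    idx = jax.random.randint(key, (4,), 0, buffer.shape[0])
    return buffer[idx]

key = jax.random.PRNGKey(0)
offspring_ids = jnp.arange(3)

# Current behaviour: the same key is broadcast to every offspring,
# so the three rows below are identical batches.
same = jax.vmap(sample, in_axes=(None, 0))(key, offspring_ids)

# With split keys, each offspring would draw distinct transitions.
distinct = jax.vmap(sample)(jax.random.split(key, 3), offspring_ids)
```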
With the default configuration, this amounts to 64 x 150 x 256 = 2,457,600 transitions, more than twice the size of the replay buffer. On humanoid, this occupies an extra 4-7 GiB of VRAM (depending on the Brax version) for no reason.
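A back-of-envelope estimate consistent with that range, under assumed sizes (the observation size and the exact set of fields stored per transition vary across Brax/QDax versions, so the numbers below are illustrative only):

```python
n_transitions = 64 * 150 * 256  # dcrl_batch_size x num_pg_training_steps x batch_size

obs_size = 244     # assumed humanoid observation size (version-dependent)
action_size = 17   # humanoid action size
desc_size = 2      # assumed descriptor size

# Assumed fields: obs + next_obs + action + reward + done + truncation
# + desc + desc_prime, all stored as float32.
floats_per_transition = 2 * obs_size + action_size + 3 + 2 * desc_size
bytes_total = n_transitions * floats_per_transition * 4
print(f"{bytes_total / 2**30:.1f} GiB")  # ~4.7 GiB with these assumptions
```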
The main issue lies in the "policy_loss_fn" generated by "make_td3_loss_dc_fn", which guides the mutation of each policy with the descriptor-conditioned critic. In policy_loss_fn, the intended descriptor is taken from the transitions themselves (transition.desc_prime), so each policy needs its own set of transitions if we want to guide the policies towards different BDs, as the sketch below illustrates.
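A self-contained sketch of this coupling, with hypothetical names and toy linear networks standing in for the real actor and critic (not QDax's actual implementation):

```python
from typing import NamedTuple
import jax.numpy as jnp

class Transition(NamedTuple):
    obs: jnp.ndarray
    desc_prime: jnp.ndarray  # intended descriptor, stored per transition

def policy_fn(params, obs, desc):
    # Toy descriptor-conditioned actor.
    return jnp.tanh(jnp.concatenate([obs, desc], axis=-1) @ params)

def critic_fn(params, obs, action, desc):
    # Toy descriptor-conditioned critic.
    return jnp.concatenate([obs, action, desc], axis=-1) @ params

def policy_loss_fn(policy_params, critic_params, transitions):
    # The target descriptor comes from the transitions, so transitions
    # cannot be shared across policies aiming at different BDs.
    actions = policy_fn(policy_params, transitions.obs, transitions.desc_prime)
    q = critic_fn(critic_params, transitions.obs, actions, transitions.desc_prime)
    return -jnp.mean(q)
```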
A quick fix is to redefine "policy_loss_fn" with an extra argument "desc", so that the same transitions can be shared across different policies, greatly reducing the memory cost.
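Continuing the toy sketch above (again with hypothetical names, not a definitive patch), the loss takes the per-policy descriptor explicitly and broadcasts it over a shared transition batch:

```python
def policy_loss_fn_fixed(policy_params, critic_params, transitions, desc):
    # Broadcast the per-policy descriptor over the shared batch.
    desc_batch = jnp.broadcast_to(desc, (transitions.obs.shape[0], desc.shape[-1]))
    actions = policy_fn(policy_params, transitions.obs, desc_batch)
    q = critic_fn(critic_params, transitions.obs, actions, desc_batch)
    return -jnp.mean(q)

# Offspring are vmapped over their own params and descriptors while the
# transition batch stays shared (in_axes=None), removing the duplication:
# losses = jax.vmap(policy_loss_fn_fixed, in_axes=(0, None, None, 0))(
#     all_policy_params, critic_params, shared_transitions, descs)
```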