Duplicated sampling from replay buffer in DCRL #206

Open
MartinMao2023 opened this issue Dec 2, 2024 · 0 comments

Comments

@MartinMao2023

The current version of DCRL_emitter duplicates its sampling of transitions from the replay buffer. The function "emit_pg" calls a vmapped version of "mutation_function_pg", which samples transitions from the replay buffer inside the mapped function. Since every offspring samples with the same random key (emitter_state.key), the transitions used for policy-gradient training are identical across offspring, yet each offspring still materialises its own copy. Hence we generate dcrl_batch_size × num_pg_training_steps × batch_size transitions at each step.
With the default configuration, that is 64 × 150 × 256 = 2,457,600 transitions, more than twice the size of the replay buffer itself. On humanoid, this occupies an extra 4–7 GiB of VRAM (depending on the Brax version) for no reason.
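
A toy illustration of why this happens (not QDax code): when a sampler that closes over a single shared key is vmapped, every lane draws the exact same batch, but each lane still materialises its own copy in memory.

```python
import jax
import jax.numpy as jnp

# Toy sampler: ignores its lane index and draws with the shared key.
def sample_batch(key, _lane):
    return jax.random.randint(key, (4,), 0, 100)

key = jax.random.PRNGKey(0)
# in_axes=(None, 0): the key is broadcast, only the lane index is mapped.
batches = jax.vmap(sample_batch, in_axes=(None, 0))(key, jnp.arange(3))
print(batches)  # three identical rows, i.e. one full copy per "offspring"
```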
The major issue is the "policy_loss_fn" generated by "make_td3_loss_dc_fn", which guides the mutation of each policy with the descriptor-conditioned critic. Inside policy_loss_fn, the target descriptor is read from the transitions themselves (transition.desc_prime), so each policy must be assigned its own set of transitions if we want to steer the policies towards different behaviour descriptors (BDs).
A quick fix is to redefine "policy_loss_fn" with an extra argument "desc", so that a single batch of transitions can be shared across policies, greatly reducing the memory cost. A sketch of this change is given below.
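
A minimal sketch of what this could look like, assuming policy_fn and critic_fn accept the descriptor as a separate input; the names and signatures only approximate make_td3_loss_dc_fn, not the exact QDax API:

```python
import jax.numpy as jnp

def make_policy_loss_dc_fn(policy_fn, critic_fn):
    # Variant of the descriptor-conditioned TD3 policy loss where the target
    # descriptor is an explicit argument instead of transitions.desc_prime.
    def policy_loss_fn(policy_params, critic_params, transitions, desc):
        # Broadcast the single per-policy descriptor over the shared batch.
        desc_prime = jnp.broadcast_to(
            desc, (transitions.obs.shape[0], desc.shape[-1])
        )
        action = policy_fn(policy_params, transitions.obs, desc_prime)
        q_value = critic_fn(critic_params, transitions.obs, action, desc_prime)
        return -jnp.mean(q_value)

    return policy_loss_fn

# Only the per-offspring params and descriptors are then vmapped; the one
# shared batch of transitions is passed through with in_axes=None:
#   losses = jax.vmap(policy_loss_fn, in_axes=(0, None, None, 0))(
#       offspring_params, critic_params, shared_transitions, offspring_descs
#   )
```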
