From 8c2a8c84e708e9872a7c4c5348c6a8173e982b60 Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Mon, 26 Aug 2024 09:18:44 -0700
Subject: [PATCH] [pmap no rank reduce cleanup]: Prepare for flipping the
 jax_pmap_no_rank_reduction flag.

This flag slows down utils.get_from_first_device, which must perform rank
reduction on every step, and that makes this test time out. The helper was
only needed because the TrainingState was not properly marked as replicated,
so we can simply mark it as replicated instead and avoid the timeout.

PiperOrigin-RevId: 667607592
Change-Id: If0a1756d38f4e1741495e8ed9d1e7cc2bb3695d3
---
 acme/agents/jax/bc/agent_test.py   |  5 ++--
 acme/agents/jax/bc/learning.py     | 44 +++++++++++++++++++-----------
 acme/agents/jax/mbop/agent_test.py |  3 +-
 acme/agents/jax/mpo/builder.py     |  4 +--
 4 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/acme/agents/jax/bc/agent_test.py b/acme/agents/jax/bc/agent_test.py
index e266b49a02..67e3e4cb9f 100644
--- a/acme/agents/jax/bc/agent_test.py
+++ b/acme/agents/jax/bc/agent_test.py
@@ -21,7 +21,6 @@
 from acme.jax import types as jax_types
 from acme.jax import utils
 from acme.testing import fakes
-import chex
 import haiku as hk
 import jax
 import jax.numpy as jnp
@@ -103,7 +102,7 @@ class BCTest(parameterized.TestCase):
       ('peerbc',)
   )
   def test_continuous_actions(self, loss_name):
-    with chex.fake_pmap_and_jit():
+    with jax.disable_jit():
       num_sgd_steps_per_step = 1
       num_steps = 5
 
@@ -145,7 +144,7 @@ def test_continuous_actions(self, loss_name):
       ('logp',),
       ('rcal',))
   def test_discrete_actions(self, loss_name):
-    with chex.fake_pmap_and_jit():
+    with jax.disable_jit():
       num_sgd_steps_per_step = 1
       num_steps = 5
 
diff --git a/acme/agents/jax/bc/learning.py b/acme/agents/jax/bc/learning.py
index 46eb2607a9..edc9ac48e7 100644
--- a/acme/agents/jax/bc/learning.py
+++ b/acme/agents/jax/bc/learning.py
@@ -150,20 +150,32 @@ def sgd_step(
     # Split the input batch to `num_sgd_steps_per_step` minibatches in order
     # to achieve better performance on accelerators.
     sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step)
-    self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
-
-    random_key, init_key = jax.random.split(random_key)
-    policy_params = networks.policy_network.init(init_key)
-    optimizer_state = optimizer.init(policy_params)
-
-    # Create initial state.
-    state = TrainingState(
-        optimizer_state=optimizer_state,
-        policy_params=policy_params,
-        key=random_key,
-        steps=0,
+    self._sgd_step = jax.pmap(
+        sgd_step,
+        axis_name=_PMAP_AXIS_NAME,
+        in_axes=(None, 0),
+        out_axes=(None, 0),
     )
-    self._state = utils.replicate_in_all_devices(state)
+
+    def init_fn(random_key):
+      random_key, init_key = jax.random.split(random_key)
+      policy_params = networks.policy_network.init(init_key)
+      optimizer_state = optimizer.init(policy_params)
+
+      # Create initial state.
+      state = TrainingState(
+          optimizer_state=optimizer_state,
+          policy_params=policy_params,
+          key=random_key,
+          steps=0,
+      )
+      return state
+
+    state = jax.pmap(init_fn, out_axes=None)(
+        utils.replicate_in_all_devices(random_key)
+    )
+    self._state = state
+    self._state_sharding = jax.tree.map(lambda x: x.sharding, state)
 
     self._timestamp = None
 
@@ -188,13 +200,13 @@ def step(self):
 
   def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
     variables = {
-        'policy': utils.get_from_first_device(self._state.policy_params),
+        'policy': self._state.policy_params,
     }
     return [variables[name] for name in names]
 
   def save(self) -> TrainingState:
     # Serialize only the first replica of parameters and optimizer state.
-    return jax.tree.map(utils.get_from_first_device, self._state)
+    return self._state
 
   def restore(self, state: TrainingState):
-    self._state = utils.replicate_in_all_devices(state)
+    self._state = jax.device_put(state, self._state_sharding)
diff --git a/acme/agents/jax/mbop/agent_test.py b/acme/agents/jax/mbop/agent_test.py
index db0fcadc3e..54124e4419 100644
--- a/acme/agents/jax/mbop/agent_test.py
+++ b/acme/agents/jax/mbop/agent_test.py
@@ -23,7 +23,6 @@
 from acme.agents.jax.mbop import networks as mbop_networks
 from acme.testing import fakes
 from acme.utils import loggers
-import chex
 import jax
 import optax
 import rlds
@@ -34,7 +33,7 @@
 class MBOPTest(absltest.TestCase):
 
   def test_learner(self):
-    with chex.fake_pmap_and_jit():
+    with jax.disable_jit():
       num_sgd_steps_per_step = 1
       num_steps = 5
       num_networks = 7
diff --git a/acme/agents/jax/mpo/builder.py b/acme/agents/jax/mpo/builder.py
index 1f6f30df82..2ce7816d32 100644
--- a/acme/agents/jax/mpo/builder.py
+++ b/acme/agents/jax/mpo/builder.py
@@ -38,7 +38,6 @@
 from acme.jax import variable_utils
 from acme.utils import counting
 from acme.utils import loggers
-import chex
 import jax
 import optax
 import reverb
@@ -162,8 +161,7 @@ def make_learner(self,
         'learner',
         steps_key=counter.get_steps_key() if counter else 'learner_steps')
 
-    with chex.fake_pmap_and_jit(not self.config.jit_learner,
-                                not self.config.jit_learner):
+    with jax.disable_jit(not self.config.jit_learner):
       learner = learning.MPOLearner(
           iterator=dataset,
           networks=networks,
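
For reference, the replicated-state pattern this change introduces in bc/learning.py can be sketched in plain JAX roughly as follows. This is a minimal, hypothetical example rather than Acme code: the toy parameters, loss, axis name 'i', and learning rate are illustrative only, and it assumes a JAX version whose pmap supports out_axes=None (as the patch itself relies on).

import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()

def init_fn(key):
  # Every device sees the same key, so the result is identical across devices
  # and out_axes=None can return it without a leading device axis.
  return {'w': jax.random.normal(key, (3,)), 'steps': jnp.zeros((), jnp.int32)}

def sgd_step(state, batch):
  loss, grads = jax.value_and_grad(lambda w: jnp.mean((batch @ w) ** 2))(state['w'])
  # Average gradients across devices so every replica applies the same update,
  # which keeps the new state replicated (required for out_axes=None).
  grads = jax.lax.pmean(grads, axis_name='i')
  new_state = {'w': state['w'] - 1e-3 * grads, 'steps': state['steps'] + 1}
  return new_state, loss

key = jax.random.PRNGKey(0)
keys = jnp.stack([key] * num_devices)  # same key on every device, like replicate_in_all_devices
state = jax.pmap(init_fn, out_axes=None)(keys)
state_sharding = jax.tree.map(lambda x: x.sharding, state)

step = jax.pmap(sgd_step, axis_name='i', in_axes=(None, 0), out_axes=(None, 0))
batch = jnp.ones((num_devices, 8, 3))  # leading axis is the device axis
state, per_device_loss = step(state, batch)

# Restoring a host-side copy mirrors the new restore(): put the values back
# onto the sharding captured at init time.
state = jax.device_put(jax.device_get(state), state_sharding)

Because the training state never carries a leading device axis in this setup, there is nothing for get_from_first_device to strip, which is what lets the learner drop that helper from get_variables, save, and restore.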