Fixing finetune.py for arbitrary # GPUs
dibyaghosh committed Dec 14, 2023
1 parent c8cd9ee commit dd4a9d6
Showing 1 changed file with 7 additions and 0 deletions.
scripts/finetune.py (7 additions, 0 deletions)

@@ -82,6 +82,13 @@ def main(_):
     #
     #########
 
+    assert (
+        FLAGS.config.batch_size % len(devices) == 0
+    ), f"Batch size ({FLAGS.config.batch_size}) must be divisible by the number of devices ({len(devices)})"
+    assert (
+        FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0
+    ), f"Eval batch size ({FLAGS.config.viz_kwargs.eval_batch_size}) must be divisible by the number of devices ({len(devices)})"
+
     # create a 1D mesh with a single axis named "batch"
     mesh = Mesh(jax.devices(), axis_names="batch")
     # Our batches will be data-parallel sharded -- each device will get a slice of the batch
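For context (not part of the commit): these checks matter because finetune.py shards each batch along the mesh's "batch" axis, and JAX refuses to place an array whose leading dimension does not divide evenly across the devices on that axis. A minimal sketch of the failure mode, using the standard jax.sharding API (the array shapes here are illustrative):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
# 1D mesh with a single "batch" axis, mirroring the mesh built in finetune.py
mesh = Mesh(np.array(devices), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Divisible batch: each device receives batch_size // len(devices) examples.
batch = np.zeros((4 * len(devices), 128))
jax.device_put(batch, sharding)  # OK

# Non-divisible batch: device_put raises a ValueError
# (only reproducible with more than one device).
bad = np.zeros((4 * len(devices) + 1, 128))
try:
    jax.device_put(bad, sharding)
except ValueError as err:
    print(f"sharding failed: {err}")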
