Commit

Merge pull request #3 from octo-models/fix_finetune
Raise a useful error message in finetune.py when batch sizes are not divisible by the number of GPUs
kvablack authored Dec 14, 2023
2 parents a0f965b + e595cd9 commit bf8c7f4
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions scripts/finetune.py
@@ -82,6 +82,13 @@ def main(_):
    #
    #########

    assert (
        FLAGS.config.batch_size % len(devices) == 0
    ), f"Batch size ({FLAGS.config.batch_size}) must be divisible by the number of devices ({len(devices)})"
    assert (
        FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0
    ), f"Eval batch size ({FLAGS.config.viz_kwargs.eval_batch_size}) must be divisible by the number of devices ({len(devices)})"

    # create a 1D mesh with a single axis named "batch"
    mesh = Mesh(jax.devices(), axis_names="batch")
    # Our batches will be data-parallel sharded -- each device will get a slice of the batch
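
The check added in this commit can also be sketched as a small standalone helper. This is only an illustration, not part of the commit: the name `per_device_batch_size` is hypothetical, and in `finetune.py` the device count would presumably come from `len(jax.devices())`. One possible reason to prefer an explicit exception over `assert` is that assertions are skipped when Python runs with `-O`.

```python
def per_device_batch_size(batch_size: int, num_devices: int, name: str = "Batch size") -> int:
    """Return the size of each device's slice, or raise if the batch cannot be split evenly.

    Hypothetical helper for illustration; it mirrors the assertion messages in the diff.
    """
    if num_devices <= 0:
        raise ValueError(f"Expected at least one device, got {num_devices}")
    if batch_size % num_devices != 0:
        raise ValueError(
            f"{name} ({batch_size}) must be divisible by the number of devices ({num_devices})"
        )
    return batch_size // num_devices


# A batch of 256 split across 8 devices gives each device 32 examples.
print(per_device_batch_size(256, 8))  # 32
```

Calling the helper once for the training batch size and once for the eval batch size (with `name="Eval batch size"`) would reproduce both checks from the diff.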