From dd4a9d67799127a503d452a842934067e1ec630e Mon Sep 17 00:00:00 2001
From: Dibya Ghosh
Date: Wed, 13 Dec 2023 21:48:04 -0800
Subject: [PATCH] Fixing finetune.py for arbitrary # GPUs

---
 scripts/finetune.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 463375ce..1d50c5c1 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -82,6 +82,13 @@ def main(_):
     #
     #########
 
+    assert (
+        FLAGS.config.batch_size % len(devices) == 0
+    ), f"Batch size ({FLAGS.config.batch_size}) must be divisible by the number of devices ({len(devices)})"
+    assert (
+        FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0
+    ), f"Eval batch size ({FLAGS.config.viz_kwargs.eval_batch_size}) must be divisible by the number of devices ({len(devices)})"
+
     # create a 1D mesh with a single axis named "batch"
     mesh = Mesh(jax.devices(), axis_names="batch")
     # Our batches will be data-parallel sharded -- each device will get a slice of the batch
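
Note for reviewers: below is a minimal, self-contained sketch of why the divisibility checks matter when sharding along a 1D "batch" mesh axis. This is not code from finetune.py -- the array name `global_batch` and the shapes are illustrative, and it assumes only a stock JAX install.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()

# 1D mesh with a single axis named "batch", mirroring the patched script
mesh = Mesh(devices, axis_names=("batch",))
batch_sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Pick a batch size that is a multiple of the device count, as the new
# asserts require; each device then holds an equal slice of axis 0.
batch_size = 8 * len(devices)
global_batch = jnp.zeros((batch_size, 32))  # (batch, features); illustrative shape
sharded_batch = jax.device_put(global_batch, batch_sharding)
print(sharded_batch.sharding)

# If batch_size were NOT divisible by len(devices), the device_put above
# would raise an error because axis 0 cannot be split evenly across the
# mesh -- which is exactly the failure the new asserts surface early,
# with a message naming the offending config value.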