diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py
index 7c244ae35..b3e9b9b50 100644
--- a/src/open_clip_train/main.py
+++ b/src/open_clip_train/main.py
@@ -422,6 +422,12 @@ def main(args):
     original_model = model
     if args.torchcompile:
         logging.info('Compiling model...')
+
+        if args.grad_checkpointing and args.distributed:
+            logging.info('Disabling DDP dynamo optimizer when grad checkpointing enabled.')
+            # As of now (~PyTorch 2.4/2.5), torch.compile works with grad checkpointing, but the DDP dynamo optimizer must be disabled
+            torch._dynamo.config.optimize_ddp = False
+
         model = torch.compile(original_model)

     if 'train' not in data:
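For context, a minimal standalone sketch of the interaction this patch works around: dynamo's DDP optimizer normally splits a compiled DDP module into bucket-sized subgraphs, which conflicts with activation checkpointing, so `torch._dynamo.config.optimize_ddp` is set to `False` before compiling. The `TinyNet` model and the single-process gloo setup below are hypothetical stand-ins (not from this PR) so the sketch can run on its own; the config flag and the `checkpoint` API are real PyTorch.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for the CLIP model in main.py."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 16)
        self.fc2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        # Grad checkpointing: fc1's activations are recomputed during backward.
        x = checkpoint(self.fc1, x, use_reentrant=False)
        return self.fc2(x)


def main():
    # Single-process "distributed" setup so the sketch runs standalone on CPU.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    grad_checkpointing, distributed = True, True  # stand-ins for args.* flags
    if grad_checkpointing and distributed:
        # Mirror the patch: disable dynamo's DDP graph-splitting optimizer
        # before compiling, since it conflicts with checkpointed regions.
        torch._dynamo.config.optimize_ddp = False

    model = DDP(TinyNet())
    compiled = torch.compile(model)

    loss = compiled(torch.randn(4, 16)).sum()
    loss.backward()

    dist.destroy_process_group()


if __name__ == '__main__':
    main()
```

Without the flag, the same compile + DDP + checkpointing combination can fail or silently fall back, which is why the patch gates it on `args.grad_checkpointing and args.distributed` rather than disabling the DDP optimizer unconditionally.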