Thanks to snek on Discord for the heads-up about Erwann Millon's post on a quick win for gradient checkpointing.
While investigating, I found that inference was running with checkpointing enabled even though we're not calculating grads, so checkpointing is now disabled before validations run too.
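A minimal sketch of that fix, assuming a diffusers-style model that exposes `is_gradient_checkpointing` plus `enable_gradient_checkpointing()` / `disable_gradient_checkpointing()` (the helper name and structure here are illustrative, not the actual patch):

```python
from contextlib import contextmanager

@contextmanager
def checkpointing_disabled(model):
    """Temporarily turn off gradient checkpointing, e.g. around validation runs.

    Checkpointing only pays off when grads are computed; during inference it
    just adds redundant recomputation, so we switch it off and restore it after.
    """
    was_on = getattr(model, "is_gradient_checkpointing", False)
    if was_on:
        model.disable_gradient_checkpointing()
    try:
        yield model
    finally:
        if was_on:
            model.enable_gradient_checkpointing()
```

Usage would be something like `with checkpointing_disabled(unet): run_validation(...)`.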
flux
Generation time drops from 29s to 14s for a 28-step image on Flux Dev, on a 4090 with a 5800X3D and fp8-quanto.
When training a rank-1 LoRA we can use
`--gradient_checkpointing_interval=2`
on a 24G 4090 with fp8-quanto, speeding up from 4.25 seconds per training step to 3.00 seconds per training step at 1024px. The trade-off is memory: without the interval you'd be able to train a much larger LoRA on 24G.
I haven't tested this on larger GPUs yet, where the extra memory allows higher intervals and the benefits should be greater.
sdxl
Implemented via a hackjob from hell, thanks to the o1-preview LLM and the tendril-like complexity of the forward-block checkpointing in the diffusers UNet code. It's a big sledgehammer that overrides the torch checkpointing function so that it only actually checkpoints every n calls.
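The gist of the override can be sketched like this; `make_interval_checkpoint` and its call counter are illustrative names for the idea, not the actual code in the PR:

```python
def make_interval_checkpoint(checkpoint_fn, interval):
    """Wrap a checkpoint function so only every n-th call actually checkpoints.

    Sketch of the sledgehammer: the wrapped function would be patched in place
    of torch.utils.checkpoint.checkpoint, with checkpoint_fn the original.
    """
    state = {"calls": 0}

    def wrapped(fn, *args, **kwargs):
        state["calls"] += 1
        if state["calls"] % interval == 0:
            # recompute-on-backward path: saves activation memory,
            # costs an extra forward pass during backward
            return checkpoint_fn(fn, *args, **kwargs)
        # plain forward: activations stay resident, backward is cheaper
        return fn(*args, **kwargs)

    return wrapped
```

With `interval=2`, half the blocks keep their activations in memory, which is where both the speedup and the extra VRAM pressure come from.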
Speeds up training from 1.00 it/sec to 1.50 it/sec using the same dataset as the Flux training test, with a LoRA rank of 16.
It's easy to OOM here with an interval that's too high, and in the case of SDXL it can happen seemingly at random after your first validations run, so be mindful of that.