gradient checkpointing speed-up #1184

Merged (5 commits) on Dec 3, 2024
Conversation

@bghira (Owner) commented on Dec 3, 2024

thanks to snek on discord for the heads-up about Erwann Millon's post on a quick win for gradient checkpointing.

while investigating, i found that inference was running with checkpointing enabled even though we're not calculating grads, so checkpointing is now disabled before validations run as well.
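
a minimal sketch of that fix, assuming a diffusers-style model that exposes enable_gradient_checkpointing() / disable_gradient_checkpointing() (the actual trainer wiring differs):

```python
import torch

# hedged sketch, not the PR's exact code: checkpointing only pays off when
# gradients are tracked, so switch it off around validation and restore it after.
def run_validation(transformer, validation_fn):
    was_checkpointing = transformer.is_gradient_checkpointing  # diffusers ModelMixin property
    if was_checkpointing:
        transformer.disable_gradient_checkpointing()
    try:
        with torch.no_grad():
            validation_fn()
    finally:
        if was_checkpointing:
            transformer.enable_gradient_checkpointing()
```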

flux

generation time drops from 29s to 14s for a 28-step image on flux dev on a 4090 with a 5800X3D and fp8-quanto

when training a rank-1 lora we can use --gradient_checkpointing_interval=2 on a 24G 4090 w/ fp8-quanto and speed up from 4.25 seconds per training step to 3.00 seconds per training step at 1024px

otherwise (with no interval set) you'd be able to train a much larger LoRA on 24G, since each skipped checkpoint keeps its activations in memory.

haven't tested this on larger GPUs yet, where the benefits are greater and higher intervals are possible.
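
a rough sketch of what the interval does inside a transformer's forward pass (illustrative names, not the actual diffusers or SimpleTuner code): only every Nth block is wrapped in torch checkpointing, so the remaining blocks keep their activations and skip the recompute in backward.

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_blocks(blocks, hidden_states, interval=2, training=True):
    # only every `interval`-th block is checkpointed; the others run normally
    for i, block in enumerate(blocks):
        if training and i % interval == 0:
            # activations discarded here and recomputed during backward (saves VRAM)
            hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
        else:
            # activations kept in memory (costs VRAM, skips the recompute)
            hidden_states = block(hidden_states)
    return hidden_states
```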

sdxl

implemented via a hackjob from hell, thanks to the o1-preview LLM and the tendril-like complexity of the forward block checkpointing in the diffusers Unet code. it's a big sledgehammer that overrides the torch checkpointing function so that it only actually checkpoints every n calls and just runs the wrapped block normally otherwise.
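
a hedged sketch of that sledgehammer, not the PR's exact code: wrap torch.utils.checkpoint.checkpoint so that only every Nth call truly checkpoints, while the other calls run the block directly.

```python
import torch.utils.checkpoint

# keep a handle on the real implementation and a simple call counter
_original_checkpoint = torch.utils.checkpoint.checkpoint
_counter = {"calls": 0}

def _interval_checkpoint(function, *args, _interval=2, **kwargs):
    _counter["calls"] += 1
    if (_counter["calls"] - 1) % _interval == 0:
        # real checkpoint: activations discarded, recomputed during backward
        return _original_checkpoint(function, *args, **kwargs)
    # skipped: run the block directly (kwargs like use_reentrant are checkpoint-only)
    return function(*args)

# monkey-patch the module attribute so existing diffusers forward code picks it up
torch.utils.checkpoint.checkpoint = _interval_checkpoint
```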

speeds up training from 1.00 it/sec to 1.50 it/sec using the same dataset as the flux training test, on a LoRA rank of 16.

it's easy to OOM here with an interval that's too high, and in the case of SDXL it can happen seemingly at random after your first validation run. so, be mindful of that.

@bghira merged commit cd0644d into main on Dec 3, 2024
1 check passed
@bghira deleted the feature/gradient-checkpointing-speedup branch on December 3, 2024 at 22:03