Merge pull request #1230 from bghira/feature/sd3-checkpointing-interval
sd3: allow setting grad checkpointing interval
bghira authored Dec 18, 2024
Parents: de9a7c3 + 18fe891 · Commit: 7815be2
Showing 4 changed files with 24 additions and 8 deletions.
helpers/configuration/cmd_args.py (1 addition & 1 deletion)

@@ -1068,7 +1068,7 @@ def get_argument_parser():
         default=None,
         type=int,
         help=(
-            "Some models (Flux, SDXL, SD1.x/2.x) can have their gradient checkpointing limited to every nth block."
+            "Some models (Flux, SDXL, SD1.x/2.x, SD3) can have their gradient checkpointing limited to every nth block."
             " This can speed up training but will use more memory with larger intervals."
         ),
     )
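For context, the hunk above shows only the keyword arguments of the option being edited. A hedged reconstruction of the full add_argument call follows; the flag name --gradient_checkpointing_interval is inferred from the args.gradient_checkpointing_interval reads later in this diff and is not shown in the hunk itself.

# Assumed reconstruction: only default/type/help appear in the hunk above;
# the flag name is inferred from args.gradient_checkpointing_interval usage.
parser.add_argument(
    "--gradient_checkpointing_interval",
    default=None,
    type=int,
    help=(
        "Some models (Flux, SDXL, SD1.x/2.x, SD3) can have their gradient"
        " checkpointing limited to every nth block. This can speed up training"
        " but will use more memory with larger intervals."
    ),
)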
helpers/models/sd3/transformer.py (18 additions & 1 deletion)

@@ -140,6 +140,16 @@ def __init__(
         )
 
         self.gradient_checkpointing = False
+        self.gradient_checkpointing_interval = None
+
+    def set_gradient_checkpointing_interval(self, interval: int):
+        """
+        Sets the interval for gradient checkpointing.
+
+        Parameters:
+            interval (`int`): The interval for gradient checkpointing.
+        """
+        self.gradient_checkpointing_interval = interval
 
     # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
     def enable_forward_chunking(

@@ -384,7 +394,14 @@ def forward(
                 )
                 continue
 
-            if self.training and self.gradient_checkpointing:
+            if (
+                self.training
+                and self.gradient_checkpointing
+                and (
+                    self.gradient_checkpointing_interval is None
+                    or index_block % self.gradient_checkpointing_interval == 0
+                )
+            ):
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
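To see what the new condition does in isolation, here is a minimal, runnable sketch of interval-gated gradient checkpointing over a generic block stack. This is illustrative code, not SimpleTuner's: the tiny linear blocks stand in for SD3's transformer blocks.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlockStack(nn.Module):
    # Illustrative stand-in for a transformer: a stack of blocks whose
    # activations are recomputed (checkpointed) only at every nth block.
    def __init__(self, num_blocks=8, dim=64):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self.gradient_checkpointing = True
        self.gradient_checkpointing_interval = None  # None means every block

    def forward(self, hidden_states):
        for index_block, block in enumerate(self.blocks):
            if (
                self.training
                and self.gradient_checkpointing
                and (
                    self.gradient_checkpointing_interval is None
                    or index_block % self.gradient_checkpointing_interval == 0
                )
            ):
                # Recompute this block in the backward pass instead of
                # storing its activations: less memory, more compute.
                hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = TinyBlockStack().train()
model.gradient_checkpointing_interval = 4  # checkpoint blocks 0 and 4 only
loss = model(torch.randn(2, 64, requires_grad=True)).sum()
loss.backward()

With 8 blocks and an interval of 4, only blocks 0 and 4 are recomputed during backward; the others keep their stored activations. A larger interval therefore trades memory for speed, matching the help text in cmd_args.py.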
helpers/training/default_settings/safety_check.py (1 addition & 5 deletions)

@@ -139,11 +139,7 @@ def safety_check(args, accelerator):
         )
         args.attention_mechanism = "diffusers"
 
-    gradient_checkpointing_interval_supported_models = [
-        "flux",
-        "sana",
-        "sdxl",
-    ]
+    gradient_checkpointing_interval_supported_models = ["flux", "sana", "sdxl", "sd3"]
    if args.gradient_checkpointing_interval is not None:
        if (
            args.model_family.lower()
helpers/training/diffusion_model.py (4 additions & 1 deletion)

@@ -179,7 +179,10 @@ def load_diffusion_model(args, weight_dtype):
 
         set_checkpoint_interval(int(args.gradient_checkpointing_interval))
 
-    if args.gradient_checkpointing_interval is not None:
+    if (
+        args.gradient_checkpointing_interval is not None
+        and args.gradient_checkpointing_interval > 1
+    ):
         if transformer is not None and hasattr(
             transformer, "set_gradient_checkpointing_interval"
         ):
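The new > 1 guard reflects that an interval of 1 selects every block, which is what plain gradient checkpointing already does, so propagating the interval to the transformer only matters for larger values. A quick standalone check of which block indices a given interval selects (hypothetical helper, mirroring the modulo test in transformer.py):

def checkpointed_blocks(num_blocks, interval):
    # Mirrors the gating in transformer.py: None means every block.
    return [
        i for i in range(num_blocks)
        if interval is None or i % interval == 0
    ]

print(checkpointed_blocks(8, None))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(checkpointed_blocks(8, 1))     # same as None, hence the > 1 guard
print(checkpointed_blocks(8, 4))     # [0, 4]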
