Merge pull request #1152 from bghira/feature/model-card-parameters
metadata: add more ddpm related schedule info to the model card
bghira authored Nov 13, 2024
2 parents 7aecc52 + ca023b5 commit 23ec487
Showing 1 changed file with 29 additions and 3 deletions.
helpers/publishing/metadata.py
@@ -362,12 +362,39 @@ def sd3_schedule_info(args):
 
     return output_str
 
 
+def ddpm_schedule_info(args):
+    """Information about DDPM schedules, eg. rescaled betas or offset noise"""
+    output_args = []
+    if args.snr_gamma:
+        output_args.append(f"snr_gamma={args.snr_gamma}")
+    if args.use_soft_min_snr:
+        output_args.append(f"use_soft_min_snr")
+        if args.soft_min_snr_sigma_data:
+            output_args.append(f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}")
+    if args.rescale_betas_zero_snr:
+        output_args.append(f"rescale_betas_zero_snr")
+    if args.offset_noise:
+        output_args.append(f"offset_noise")
+        output_args.append(f"noise_offset={args.noise_offset}")
+        output_args.append(f"noise_offset_probability={args.noise_offset_probability}")
+    output_args.append(f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}")
+    output_args.append(f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}")
+    output_str = (
+        f" (extra parameters={output_args})"
+        if output_args
+        else " (no special parameters set)"
+    )
+
+    return output_str
+
+
 def model_schedule_info(args):
     if args.model_family == "flux":
         return flux_schedule_info(args)
     if args.model_family == "sd3":
         return sd3_schedule_info(args)
+    else:
+        return ddpm_schedule_info(args)
 
 
 def save_model_card(
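A minimal usage sketch of the new helper (not part of the commit): the SimpleNamespace below stands in for the trainer's parsed argument namespace, the option values are hypothetical, and the expected output assumes the branch nesting reconstructed above.

    from types import SimpleNamespace

    # Hypothetical option values; each field mirrors an attribute that
    # ddpm_schedule_info() reads from the real argparse namespace.
    args = SimpleNamespace(
        snr_gamma=5.0,
        use_soft_min_snr=False,
        soft_min_snr_sigma_data=None,
        rescale_betas_zero_snr=True,
        offset_noise=False,
        noise_offset=0.1,
        noise_offset_probability=0.25,
        training_scheduler_timestep_spacing="trailing",
        validation_scheduler_timestep_spacing="trailing",
    )

    print(ddpm_schedule_info(args))
    # -> (extra parameters=['snr_gamma=5.0', 'rescale_betas_zero_snr',
    #    'training_scheduler_timestep_spacing=trailing',
    #    'validation_scheduler_timestep_spacing=trailing'])

Note that, as reconstructed, the two timestep-spacing entries are appended unconditionally, so the "(no special parameters set)" fallback is effectively unreachable for DDPM-family models.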
@@ -495,10 +522,9 @@ def save_model_card(
 - Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
 - Number of GPUs: {StateTracker.get_accelerator().num_processes}
 - Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
-- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
 - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
-- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
-- Quantised: {f'Yes: {StateTracker.get_args().base_model_precision}' if StateTracker.get_args().base_model_precision != "no_change" else 'No'}
+- Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
+- Quantised base model: {f'Yes ({StateTracker.get_args().base_model_precision})' if StateTracker.get_args().base_model_precision != "no_change" else 'No'}
 - Xformers: {'Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'Not used'}
 {lora_info(args=StateTracker.get_args())}
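With the new fallback in place, a legacy (non-flux, non-sd3) run's card would render the prediction-type line along these lines (illustrative values, not taken from a real run):

    - Prediction type: epsilon (extra parameters=['snr_gamma=5.0', 'rescale_betas_zero_snr', 'training_scheduler_timestep_spacing=trailing', 'validation_scheduler_timestep_spacing=trailing'])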
