From ca023b59ec84cfdb54f704726164e12236def7d6 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 13 Nov 2024 19:04:36 +0000 Subject: [PATCH] metadata: add more ddpm related schedule info to the model card --- helpers/publishing/metadata.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index b3cd0568..4aeccbc5 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -362,12 +362,39 @@ def sd3_schedule_info(args): return output_str +def ddpm_schedule_info(args): + """Information about DDPM schedules, e.g. rescaled betas or offset noise""" + output_args = [] + if args.snr_gamma: + output_args.append(f"snr_gamma={args.snr_gamma}") + if args.use_soft_min_snr: + output_args.append("use_soft_min_snr") + if args.soft_min_snr_sigma_data: + output_args.append(f"soft_min_snr_sigma_data={args.soft_min_snr_sigma_data}") + if args.rescale_betas_zero_snr: + output_args.append("rescale_betas_zero_snr") + if args.offset_noise: + output_args.append("offset_noise") + output_args.append(f"noise_offset={args.noise_offset}") + output_args.append(f"noise_offset_probability={args.noise_offset_probability}") + output_args.append(f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}") + output_args.append(f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}") + output_str = ( + f" (extra parameters={output_args})" + if output_args + else " (no special parameters set)" + ) + + return output_str def model_schedule_info(args): if args.model_family == "flux": return flux_schedule_info(args) if args.model_family == "sd3": return sd3_schedule_info(args) + else: + return ddpm_schedule_info(args) + + def save_model_card( @@ -495,10 +522,9 @@ - Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps} - Number of GPUs: 
{StateTracker.get_accelerator().num_processes} - Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())} -- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr} - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''} -- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} -- Quantised: {f'Yes: {StateTracker.get_args().base_model_precision}' if StateTracker.get_args().base_model_precision != "no_change" else 'No'} +- Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} +- Quantised base model: {f'Yes ({StateTracker.get_args().base_model_precision})' if StateTracker.get_args().base_model_precision != "no_change" else 'No'} - Xformers: {'Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'Not used'} {lora_info(args=StateTracker.get_args())}