Merge pull request #1133 from bghira/feature/sd3-model-card-details
revamp model card to work by default and provide quanto hints
bghira authored Nov 10, 2024
2 parents 3293fa0 + 49eb37a commit c2701f6
Showing 13 changed files with 237 additions and 63 deletions.
7 changes: 5 additions & 2 deletions documentation/quickstart/SD3.md
@@ -276,7 +276,9 @@ The following values are recommended for `config.json`:
"--validation_guidance_skip_layers_start": 0.01,
"--validation_guidance_skip_layers_stop": 0.2,
"--validation_guidance_skip_scale": 2.8,
"--validation_guidance": 4.0
"--validation_guidance": 4.0,
"--flux_use_uniform_schedule": true,
"--flux_schedule_auto_shift": true
}
```

@@ -308,7 +310,8 @@ Some changes were made to SimpleTuner's SD3.5 support:
- Offering switches (`--sd3_clip_uncond_behaviour` and `--sd3_t5_uncond_behaviour`) to build unconditional predictions from encoded blank captions (`empty_string`, **default**) or from zeros (`zero`); these are not recommended settings to tweak.
- SD3.5 training loss function was updated to match that found in the upstream StabilityAI/SD3.5 repository
- Updated default `--flux_schedule_shift` value to 3 to match the static 1024px value for SD3
- 512px training requires the use of `--flux_schedule_shift=1`
- StabilityAI followed up with documentation recommending `--flux_schedule_shift=1` together with `--flux_use_uniform_schedule`
- Community members have reported that `--flux_schedule_auto_shift` works better for multi-aspect or multi-resolution training (see the example config below)
- Updated the hard-coded tokeniser sequence length limit to **256**, with the option to revert it to **77** tokens to save disk space or compute at the cost of some output quality degradation
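
For concreteness, a minimal `config.json` sketch of the 512px guidance above (values are illustrative, not a tested recommendation). For multi-aspect or multi-resolution runs you would instead set `"--flux_schedule_auto_shift": true` and drop the explicit shift, since the safety check added in this commit rejects combining a non-zero `--flux_schedule_shift` with `--flux_schedule_auto_shift`:

```json
{
    "--flux_schedule_shift": 1,
    "--flux_use_uniform_schedule": true
}
```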


16 changes: 10 additions & 6 deletions helpers/caching/text_embeds.py
@@ -249,9 +249,13 @@ def batch_write_embeddings(self):
if len(batch) > 0:
self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
logger.debug(f"Exiting batch write thread, no more work to do after writing {written_elements} elements")
logger.debug(
f"Exiting batch write thread, no more work to do after writing {written_elements} elements"
)
break
logger.debug(f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}")
logger.debug(
f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}"
)
pass
except Exception:
logger.exception("An error occurred while writing embeddings to disk.")
@@ -525,9 +529,7 @@ def encode_prompt(self, prompt: str, is_validation: bool = False):
prompt,
is_validation,
zero_padding_tokens=(
True
if StateTracker.get_args().t5_padding == "zero"
else False
True if StateTracker.get_args().t5_padding == "zero" else False
),
)
else:
@@ -1320,7 +1322,9 @@ def compute_embeddings_for_sd3_prompts(
)
if should_encode:
# If load_from_cache is True, should_encode would be False unless we failed to load.
self.debug_log(f"Encoding filename {filename} :: device {self.text_encoders[0].device} :: prompt {prompt}")
self.debug_log(
f"Encoding filename {filename} :: device {self.text_encoders[0].device} :: prompt {prompt}"
)
prompt_embeds, pooled_prompt_embeds = self.encode_sd3_prompt(
self.text_encoders,
self.tokenizers,
8 changes: 5 additions & 3 deletions helpers/data_backend/factory.py
@@ -50,7 +50,9 @@ def info_log(message):
logger.info(message)


def check_column_values(column_data, column_name, parquet_path, fallback_caption_column=False):
def check_column_values(
column_data, column_name, parquet_path, fallback_caption_column=False
):
# Determine if the column contains arrays or scalar values
non_null_values = column_data.dropna()
if non_null_values.empty:
@@ -362,15 +364,15 @@ def configure_parquet_database(backend: dict, args, data_backend: BaseDataBacken
df[caption_column],
caption_column,
parquet_path,
fallback_caption_column=fallback_caption_column
fallback_caption_column=fallback_caption_column,
)

# Apply the function to the filename_column.
check_column_values(
df[filename_column],
filename_column,
parquet_path,
fallback_caption_column=False # Always check filename_column
fallback_caption_column=False, # Always check filename_column
)

# Store the database in StateTracker
5 changes: 4 additions & 1 deletion helpers/image_manipulation/training_sample.py
@@ -80,7 +80,10 @@ def __init__(
self.resolution = self.data_backend_config.get("resolution")
self.resolution_type = self.data_backend_config.get("resolution_type")
self.target_size_calculator = resize_helpers.get(self.resolution_type)
if self.target_size_calculator is None and conditioning_type not in ["mask", "controlnet"]:
if self.target_size_calculator is None and conditioning_type not in [
"mask",
"controlnet",
]:
raise ValueError(f"Unknown resolution type: {self.resolution_type}")
self._set_resolution()
self.target_downsample_size = self.data_backend_config.get(
12 changes: 6 additions & 6 deletions helpers/models/flux/pipeline.py
@@ -824,16 +824,16 @@ def __call__(

noise_pred = self.transformer(
hidden_states=latents.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
encoder_hidden_states=prompt_embeds.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
txt_ids=text_ids,
img_ids=latent_image_ids,
@@ -846,16 +846,16 @@
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
noise_pred_uncond = self.transformer(
hidden_states=latents.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
encoder_hidden_states=negative_prompt_embeds.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
txt_ids=negative_text_ids.to(device=self.transformer.device),
img_ids=latent_image_ids.to(device=self.transformer.device),
4 changes: 3 additions & 1 deletion helpers/prompts.py
@@ -266,7 +266,9 @@ def prepare_instance_prompt_from_parquet(
if type(image_caption) == str:
image_caption = image_caption.strip()
if type(image_caption) in (list, tuple, numpy.ndarray, pd.Series):
image_caption = [str(item).strip() for item in image_caption if item is not None]
image_caption = [
str(item).strip() for item in image_caption if item is not None
]
if prepend_instance_prompt:
if type(image_caption) == list:
image_caption = [instance_prompt + " " + x for x in image_caption]
82 changes: 74 additions & 8 deletions helpers/publishing/metadata.py
@@ -105,7 +105,38 @@ def _model_imports(args):
return f"{output}"


def lycoris_download_info():
"""output a function to download the adapter"""
output_fn = """
def download_adapter(repo_id: str):
import os
from huggingface_hub import hf_hub_download
adapter_filename = "pytorch_lora_weights.safetensors"
cache_dir = os.environ.get('HF_PATH', os.path.expanduser('~/.cache/huggingface/hub/models'))
cleaned_adapter_path = repo_id.replace("/", "_").replace("\\\\", "_").replace(":", "_")
path_to_adapter = os.path.join(cache_dir, cleaned_adapter_path)
path_to_adapter_file = os.path.join(path_to_adapter, adapter_filename)
os.makedirs(path_to_adapter, exist_ok=True)
hf_hub_download(
repo_id=repo_id, filename=adapter_filename, local_dir=path_to_adapter
)
return path_to_adapter_file
"""

return output_fn


def _model_component_name(args):
model_component_name = "pipeline.transformer"
if args.model_family in ["sdxl", "kolors", "legacy", "deepfloyd"]:
model_component_name = "pipeline.unet"

return model_component_name


def _model_load(args, repo_id: str = None):
model_component_name = _model_component_name(args)
hf_user_name = StateTracker.get_hf_username()
if hf_user_name is not None:
repo_id = f"{hf_user_name}/{repo_id}" if hf_user_name else repo_id
@@ -114,22 +145,26 @@ def _model_load(args, repo_id: str = None):
output = (
f"model_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_id = '{repo_id if repo_id is not None else args.output_dir}'"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id)"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id), torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16"
f"\npipeline.load_lora_weights(adapter_id)"
)
elif args.lora_type.lower() == "lycoris":
output = (
f"model_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_id = 'pytorch_lora_weights.safetensors' # you will have to download this manually"
f"{lycoris_download_info()}"
f"\nmodel_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_repo_id = '{repo_id if repo_id is not None else args.output_dir}'"
f"\nadapter_filename = 'pytorch_lora_weights.safetensors'"
f"\nadapter_file_path = download_adapter(repo_id=adapter_repo_id)"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16"
"\nlora_scale = 1.0"
)
else:
output = (
f"model_id = '{repo_id if repo_id else os.path.join(args.output_dir, 'pipeline')}'"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id)"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16"
)
if args.model_type == "lora" and args.lora_type.lower() == "lycoris":
output += f"\nwrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipeline.transformer)"
output += f"\nwrapper, _ = create_lycoris_from_weights(lora_scale, adapter_file_path, {model_component_name})"
output += "\nwrapper.merge_to()"

return output
@@ -162,6 +197,33 @@ def _skip_layers(args):
return f"\n skip_guidance_layers={args.validation_guidance_skip_layers},"


def _pipeline_move_to(args):
output = f"pipeline.to({_torch_device()}) # the pipeline is already in its target precision level"

return output


def _pipeline_quanto(args):
# return some optional lines to run Quanto on the model pipeline
if args.model_type == "full":
return ""
model_component_name = _model_component_name(args)
comment_character = ""
was_quantised = "The model was quantised during training, and so it is recommended to do the same during inference time."
if args.base_model_precision == "no_change":
comment_character = "#"
was_quantised = "The model was not quantised during training, so it is not necessary to quantise it during inference time."
output = f"""
## Optional: quantise the model to save on vram.
## Note: {was_quantised}
{comment_character}from optimum.quanto import quantize, freeze, qint8
{comment_character}quantize({model_component_name}, weights=qint8)
{comment_character}freeze({model_component_name})
"""

return output


def _validation_resolution(args):
if args.validation_resolution == "" or args.validation_resolution is None:
return f"width=1024,\n" f" height=1024,"
Expand All @@ -188,13 +250,14 @@ def code_example(args, repo_id: str = None):
prompt = "{args.validation_prompt if args.validation_prompt else 'An astronaut is riding a horse through the jungles of Thailand.'}"
{_negative_prompt(args)}
pipeline.to({_torch_device()})
{_pipeline_quanto(args)}
{_pipeline_move_to(args)}
image = pipeline(
prompt=prompt,{_negative_prompt(args, in_call=True) if args.model_family.lower() != 'flux' else ''}
num_inference_steps={args.validation_num_inference_steps},
generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
{_validation_resolution(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)},{_skip_layers(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
).images[0]
image.save("output.png", format="PNG")
```
@@ -226,7 +289,10 @@ def lora_info(args):
lycoris_config_file = args.lycoris_config
# read the json file
with open(lycoris_config_file, "r") as file:
lycoris_config = json.load(file)
try:
lycoris_config = json.load(file)
except Exception:
lycoris_config = {"error": "could not locate or load LyCORIS config."}
return f"""- LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""


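For reference, here is a rough sketch of the inference example these helpers assemble for a LyCORIS adapter with the quanto hints enabled. The base model, adapter repo, device, and generation settings are placeholders chosen for illustration, not values produced by the commit:

```python
# Illustrative sketch only: model id, adapter repo, and generation settings are placeholders.
import os

import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from lycoris import create_lycoris_from_weights
from optimum.quanto import quantize, freeze, qint8


def download_adapter(repo_id: str):
    # Mirrors the download_adapter() helper emitted by lycoris_download_info().
    adapter_filename = "pytorch_lora_weights.safetensors"
    cache_dir = os.environ.get(
        "HF_PATH", os.path.expanduser("~/.cache/huggingface/hub/models")
    )
    cleaned_adapter_path = repo_id.replace("/", "_").replace("\\", "_").replace(":", "_")
    path_to_adapter = os.path.join(cache_dir, cleaned_adapter_path)
    os.makedirs(path_to_adapter, exist_ok=True)
    hf_hub_download(
        repo_id=repo_id, filename=adapter_filename, local_dir=path_to_adapter
    )
    return os.path.join(path_to_adapter, adapter_filename)


model_id = "stabilityai/stable-diffusion-3.5-large"  # placeholder base model
adapter_repo_id = "your-username/your-lycoris-adapter"  # placeholder adapter repo
adapter_file_path = download_adapter(repo_id=adapter_repo_id)

# Load directly in the training weight dtype (bf16 here) rather than casting afterwards.
pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Optional: quantise the transformer to save VRAM, mirroring quantised training.
quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)

# Merge the LyCORIS weights into the transformer before inference.
lora_scale = 1.0
wrapper, _ = create_lycoris_from_weights(lora_scale, adapter_file_path, pipeline.transformer)
wrapper.merge_to()

pipeline.to("cuda")  # the pipeline is already in its target precision level
image = pipeline(
    prompt="An astronaut is riding a horse through the jungles of Thailand.",
    num_inference_steps=28,
    guidance_scale=4.0,
    generator=torch.Generator(device="cuda").manual_seed(1641421826),
).images[0]
image.save("output.png", format="PNG")
```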
3 changes: 2 additions & 1 deletion helpers/training/default_settings/safety_check.py
@@ -109,7 +109,8 @@ def safety_check(args, accelerator):
sys.exit(1)

if (
args.flux_schedule_shift is not None and args.flux_schedule_shift > 0
args.flux_schedule_shift is not None
and args.flux_schedule_shift > 0
and args.flux_schedule_auto_shift
):
logger.error(
12 changes: 8 additions & 4 deletions helpers/training/trainer.py
@@ -2183,9 +2183,11 @@ def train(self):
)
training_logger.debug(f"Working on batch size: {bsz}")
if self.config.flow_matching:
if (
not self.config.flux_fast_schedule
and not any([self.config.flux_use_beta_schedule, self.config.flux_use_uniform_schedule])
if not self.config.flux_fast_schedule and not any(
[
self.config.flux_use_beta_schedule,
self.config.flux_use_uniform_schedule,
]
):
# imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
# also used by: https://github.com/XLabs-AI/x-flux/tree/main
@@ -2316,7 +2318,9 @@ def train(self):
elif self.config.flow_matching_loss == "compatible":
target = noise - latents
elif self.config.flow_matching_loss == "sd35":
sigma_reshaped = sigmas.view(-1, 1, 1, 1) # Ensure sigma has the correct shape
sigma_reshaped = sigmas.view(
-1, 1, 1, 1
) # Ensure sigma has the correct shape
target = (noisy_latents - latents) / sigma_reshaped

elif self.noise_scheduler.config.prediction_type == "epsilon":
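As a reading aid for the loss branch above (notation assumed rather than taken from the source: $x_0$ the clean latents, $x_t$ the noised latents, $\epsilon$ the sampled noise, $\sigma$ the flow-matching sigma), the two flow-matching targets reduce to:

```latex
\text{compatible:}\quad v_{\text{target}} = \epsilon - x_0
\qquad\qquad
\text{sd35:}\quad v_{\text{target}} = \frac{x_t - x_0}{\sigma}
```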
12 changes: 9 additions & 3 deletions helpers/training/validation.py
@@ -642,7 +642,9 @@ def _gather_prompt_embeds(self, validation_prompt: str):
device=self.inference_device, dtype=self.weight_dtype
)
)
prompt_embeds["pooled_prompt_embeds"] = current_validation_pooled_embeds
prompt_embeds["pooled_prompt_embeds"] = current_validation_pooled_embeds.to(
device=self.inference_device, dtype=self.weight_dtype
)
prompt_embeds["negative_pooled_prompt_embeds"] = (
self.validation_negative_pooled_embeds
)
@@ -662,7 +664,9 @@ def _gather_prompt_embeds(self, validation_prompt: str):
current_validation_prompt_embeds, current_validation_prompt_mask = (
current_validation_prompt_embeds
)
current_validation_prompt_embeds = current_validation_prompt_embeds[0]
current_validation_prompt_embeds = current_validation_prompt_embeds[
0
].to(device=self.inference_device, dtype=self.weight_dtype)
if (
type(self.validation_negative_prompt_embeds) is tuple
or type(self.validation_negative_prompt_embeds) is list
@@ -672,7 +676,9 @@ def _gather_prompt_embeds(self, validation_prompt: str):
self.validation_negative_prompt_mask,
) = self.validation_negative_prompt_embeds[0]
else:
current_validation_prompt_embeds = current_validation_prompt_embeds[0]
current_validation_prompt_embeds = current_validation_prompt_embeds[
0
].to(device=self.inference_device, dtype=self.weight_dtype)
# logger.debug(
# f"Validations received the prompt embed: ({type(current_validation_prompt_embeds)}) positive={current_validation_prompt_embeds.shape if type(current_validation_prompt_embeds) is not list else current_validation_prompt_embeds[0].shape},"
# f" ({type(self.validation_negative_prompt_embeds)}) negative={self.validation_negative_prompt_embeds.shape if type(self.validation_negative_prompt_embeds) is not list else self.validation_negative_prompt_embeds[0].shape}"