
wip: omnigen initial training support attempt #1117

Draft: wants to merge 6 commits into main
8 changes: 6 additions & 2 deletions configure.py
@@ -18,12 +18,13 @@
"full": [
"flux",
"sdxl",
"omnigen",
"pixart_sigma",
"kolors",
"sd3",
"legacy",
],
"lora": ["flux", "sdxl", "kolors", "sd3", "legacy"],
"lora": ["flux", "sdxl", "kolors", "sd3", "legacy", "omnigen"],
"controlnet": ["sdxl", "legacy"],
}

@@ -34,6 +35,7 @@
"kolors": "kwai-kolors/kolors-diffusers",
"terminus": "ptx0/terminus-xl-velocity-v2",
"sd3": "stabilityai/stable-diffusion-3.5-large",
"omnigen": "Shitao/OmniGen-v1",
"legacy": "stabilityai/stable-diffusion-2-1-base",
}

@@ -43,12 +45,14 @@
"pixart_sigma": 3.4,
"kolors": 5.0,
"terminus": 8.0,
"sd3": 5.0,
"omnigen": 3.0,
"sd3": 6.0,
}

model_labels = {
"sd3": "Stable Diffusion 3",
"flux": "FLUX",
"omnigen": "OmniGen",
"pixart_sigma": "PixArt Sigma",
"kolors": "Kwai Kolors",
"terminus": "Terminus",
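Note on the configure.py change above: the new "omnigen" entries add the family to the full and LoRA training modes, point it at the Shitao/OmniGen-v1 checkpoint, and give it what appears to be a default guidance (CFG) value of 3.0, while the SD3 value moves from 5.0 to 6.0. A minimal sketch of how these lookup tables can be queried; the dicts and helper below are purely illustrative and are not part of configure.py:

# Illustrative only: values mirror the diff above, the helper is not in configure.py.
supported_modes = {
    "full": ["flux", "sdxl", "omnigen", "pixart_sigma", "kolors", "sd3", "legacy"],
    "lora": ["flux", "sdxl", "kolors", "sd3", "legacy", "omnigen"],
    "controlnet": ["sdxl", "legacy"],
}
default_models = {"omnigen": "Shitao/OmniGen-v1"}
default_cfg = {"omnigen": 3.0}

def modes_for(family: str) -> list:
    # Return the training modes whose allow-list includes this model family.
    return [mode for mode, families in supported_modes.items() if family in families]

print(modes_for("omnigen"))  # ['full', 'lora'] -- no controlnet support yet
print(default_models["omnigen"], default_cfg["omnigen"])  # Shitao/OmniGen-v1 3.0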
212 changes: 193 additions & 19 deletions helpers/caching/text_embeds.py
@@ -208,21 +208,14 @@ def discover_all_files(self):

def save_to_cache(self, filename, embeddings):
"""Add write requests to the queue instead of writing directly."""
if not self.batch_write_thread.is_alive():
logger.debug("Restarting background write thread.")
# Start the thread again.
self.process_write_batches = True
self.batch_write_thread = Thread(target=self.batch_write_embeddings)
self.batch_write_thread.start()
self.process_write_batches = True
self.write_queue.put((embeddings, filename))
logger.debug(
f"save_to_cache called for {filename}, write queue has {self.write_queue.qsize()} items, and the write thread's status: {self.batch_write_thread.is_alive()}"
)

def batch_write_embeddings(self):
"""Process write requests in batches."""
batch = []
written_elements = 0
while True:
try:
# Block until an item is available or timeout occurs
@@ -233,38 +226,30 @@ def batch_write_embeddings(self):
while (
not self.write_queue.empty() and len(batch) < self.write_batch_size
):
logger.debug("Retrieving more items from the queue.")
items = self.write_queue.get_nowait()
batch.append(items)
logger.debug(f"Batch now contains {len(batch)} items.")

self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
logger.debug("Processed batch write.")
written_elements += len(batch)

except queue.Empty:
# Timeout occurred, no items were ready
if not self.process_write_batches:
if len(batch) > 0:
self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
logger.debug(
f"Exiting batch write thread, no more work to do after writing {written_elements} elements"
)
logger.debug(f"Exiting batch write thread, no more work to do")
break
logger.debug(
f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}"
)
pass
except Exception:
logger.exception("An error occurred while writing embeddings to disk.")
logger.debug("Exiting background batch write thread.")

def process_write_batch(self, batch):
"""Write a batch of embeddings to the cache."""
logger.debug(f"Writing {len(batch)} items to disk")
logger.debug(f"Batch: {batch}")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(self.data_backend.torch_save, *args) for args in batch
@@ -292,8 +277,8 @@ def encode_flux_prompt(
text_encoders: List of text encoders.
tokenizers: List of tokenizers.
prompt: The prompt to encode.
num_images_per_prompt: The number of images to generate per prompt.
is_validation: Whether the prompt is for validation. No-op for SD3.
zero_padding_tokens: Whether to zero out padding tokens.

Returns:
Tuple of (prompt_embeds, pooled_prompt_embeds).
@@ -324,6 +309,31 @@

return prompt_embeds, pooled_prompt_embeds, time_ids, masks

def encode_omnigen_prompt(
self, text_encoders, tokenizers, prompt: str, is_validation: bool = False
):
"""
Encode a prompt for an OmniGen model.

Args:
text_encoders: List of text encoders.
tokenizers: List of tokenizers.
prompt: The prompt to encode.
is_validation: Whether the prompt is for validation. No-op for OmniGen.

Returns:
Dict of OmniGen inputs
"""
# This is not a text encoder; it is the OmniGen MLLM preprocessor / tokenizer.
processed = text_encoders[0](
instructions=prompt,
# use_img_cfg=False,
# separate_cfg_input=False,
# use_input_image_size_as_output=False,
)

return processed

# Adapted from pipelines.StableDiffusion3Pipeline.encode_prompt
def encode_sd3_prompt(
self,
@@ -668,6 +678,12 @@ def compute_embeddings_for_prompts(
return_concat=return_concat,
load_from_cache=load_from_cache,
)
elif self.model_type == "omnigen":
output = self.compute_embeddings_for_omnigen_prompts(
raw_prompts,
return_concat=return_concat,
load_from_cache=load_from_cache,
)
else:
raise ValueError(
f"No such text encoding backend for model type '{self.model_type}'"
@@ -1041,6 +1057,164 @@ def compute_embeddings_for_legacy_prompts(
return prompt_embeds_all, attention_masks_all
return prompt_embeds_all

def compute_embeddings_for_omnigen_prompts(
self,
prompts: list = None,
return_concat: bool = True,
is_validation: bool = False,
load_from_cache: bool = True,
):
# print(f"Computing embeddings for Omnigen prompts")
# processed = self.text_encoders[0](
# prompts,
# )

# # processed looks like:
# # {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
# print(f"Processed: {processed.keys()}")

# return processed
processed_all = []
should_encode = not load_from_cache
args = StateTracker.get_args()
if should_encode:
local_caption_split = self.split_captions_between_processes(
prompts or self.prompts
)
else:
local_caption_split = prompts or self.prompts
if (
hasattr(args, "cache_clear_validation_prompts")
and args.cache_clear_validation_prompts
and is_validation
):
# If --cache_clear_validation_prompts was provided, we will forcibly overwrite them.
load_from_cache = False
should_encode = True

if self.webhook_handler is not None:
last_reported_index = 0
self.send_progress_update(
type="init_cache_text_embeds_started",
progress=int(0 // len(local_caption_split)),
total=len(local_caption_split),
current=0,
)
self.write_thread_bar = tqdm(
desc="Write embeds to disk",
leave=False,
ncols=125,
disable=return_concat,
total=len(local_caption_split),
position=get_rank(),
)
with torch.no_grad():
last_reported_index = 0
for prompt in tqdm(
local_caption_split,
desc="Processing prompts",
disable=return_concat,
miniters=50,
leave=False,
ncols=125,
position=get_rank() + self.accelerator.num_processes + 1,
):
filename = os.path.join(self.cache_dir, self.hash_prompt(prompt))
debug_msg = f"Processing file: {filename}, prompt: {prompt}"
prompt = PromptHandler.filter_caption(self.data_backend, prompt)
debug_msg = f"{debug_msg}\n -> filtered prompt: {prompt}"
if prompt is None:
logger.error(f"Filename {filename} does not have a caption.")
continue
logger.debug(debug_msg)
if return_concat and load_from_cache:
try:
# We attempt to load.
_processed = self.load_from_cache(filename)
logger.debug(f"Cached OmniGen inputs: {_processed}")
except Exception as e:
# We failed to load. Now encode the prompt.
logger.error(
f"Failed retrieving prompt from cache:"
f"\n-> prompt: {prompt}"
f"\n-> filename: {filename}"
f"\n-> error: {e}"
f"\n-> id: {self.id}, data_backend id: {self.data_backend.id}"
)
should_encode = True
raise Exception(
"Cache retrieval for text embed file failed. Ensure your dataloader config value for skip_file_discovery does not contain 'text', and that preserve_data_backend_cache is disabled or unset."
)
if should_encode:
# If load_from_cache is True, should_encode would be False unless we failed to load.
self.debug_log(f"Encoding prompt: {prompt}")
_processed = self.encode_omnigen_prompt(
self.text_encoders, self.tokenizers, [prompt], is_validation
)
logger.debug(f"OmniGen prompt embeds: {_processed}")
current_size = self.write_queue.qsize()
if current_size >= 2048:
log_msg = str(
f"[WARNING] Write queue size is {current_size}. This is quite large."
" Consider increasing the write batch size. Delaying encode so that writes can catch up."
)
self.write_thread_bar.write(log_msg)
while self.write_queue.qsize() > 100:
time.sleep(0.1)

self.debug_log(f"Adding embed to write queue: {filename}")
self.save_to_cache(filename, _processed)
if (
self.webhook_handler is not None
and int(
self.write_thread_bar.n % self.webhook_progress_interval
)
< 10
):
last_reported_index = int(
self.write_thread_bar.n % self.webhook_progress_interval
)
self.send_progress_update(
type="init_cache_text_embeds_status_update",
progress=int(
self.write_thread_bar.n
/ len(local_caption_split)
* 100
),
total=len(local_caption_split),
current=0,
)

if not return_concat:
del _processed
continue

if return_concat:
processed_all.append(_processed)

while self.write_queue.qsize() > 0:
time.sleep(0.1) # Sleep briefly to avoid busy-waiting

if self.webhook_handler is not None:
self.send_progress_update(
type="init_cache_text_embeds_status_complete",
progress=100,
total=len(local_caption_split),
current=len(local_caption_split),
)

# Close the tqdm progress bar after the loop
self.write_thread_bar.close()
self.process_write_batches = False

if not return_concat:
del processed_all
return

logger.debug(f"Returning all prompt embeds: {processed_all}")

return processed_all

def compute_embeddings_for_flux_prompts(
self,
prompts: list = None,
@@ -1337,7 +1511,7 @@ def compute_embeddings_for_sd3_prompts(
),
)
logger.debug(
f"Filename {filename} SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}"
f"SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}"
)
add_text_embeds = pooled_prompt_embeds
# StabilityAI say not to zero them out.
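Note on the text_embeds.py changes: unlike the Flux and SD3 paths, which cache prompt-embedding tensors, the OmniGen path caches the raw output of the MLLM processor (a dict of input_ids, pixel_values and image_sizes). A rough sketch of that flow using the attribute and method names visible in the diff; the helper itself and the surrounding cache object are assumptions for illustration, not part of this PR:

import os

def cache_omnigen_prompt(embed_cache, prompt: str) -> dict:
    # `embed_cache` is assumed to be the text-embedding cache object from this file,
    # exposing the attributes used in the diff above.
    processed = embed_cache.encode_omnigen_prompt(
        embed_cache.text_encoders, embed_cache.tokenizers, [prompt]
    )
    # `processed` is a dict of tokenized MLLM inputs, not an embedding tensor.
    filename = os.path.join(embed_cache.cache_dir, embed_cache.hash_prompt(prompt))
    embed_cache.save_to_cache(filename, processed)  # queued for the background writer thread
    return processed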
14 changes: 11 additions & 3 deletions helpers/configuration/cmd_args.py
@@ -86,7 +86,16 @@ def get_argument_parser():
)
parser.add_argument(
"--model_family",
choices=["pixart_sigma", "kolors", "sd3", "flux", "smoldit", "sdxl", "legacy"],
choices=[
"omnigen",
"pixart_sigma",
"kolors",
"sd3",
"flux",
"smoldit",
"sdxl",
"legacy",
],
default=None,
required=True,
help=("The model family to train. This option is required."),
@@ -2169,7 +2178,7 @@ def parse_cmdline_args(input_args=None):

if (
args.pretrained_vae_model_name_or_path is not None
and args.model_family in ["legacy", "flux", "sd3"]
and args.model_family in ["legacy", "flux", "sd3", "omnigen"]
and "sdxl" in args.pretrained_vae_model_name_or_path
and "deepfloyd" not in args.model_type
):
@@ -2199,7 +2208,6 @@
info_log(
f"SD3 embeds for unconditional captions: t5={args.sd3_t5_uncond_behaviour}, clip={args.sd3_clip_uncond_behaviour}"
)

elif "deepfloyd" in args.model_type:
deepfloyd_pixel_alignment = 8
if args.aspect_bucket_alignment != deepfloyd_pixel_alignment:
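Note on the cmd_args.py change: --model_family now accepts omnigen, and the SDXL-VAE sanity check covers it as well. A quick way to confirm the parser picked up the new choice (a sketch only, assuming the repo root is on PYTHONPATH; a real training run still needs the rest of the required arguments):

from helpers.configuration.cmd_args import get_argument_parser

parser = get_argument_parser()
# Inspect the registered --model_family action rather than doing a full parse,
# since the other required flags are omitted here.
family_action = next(a for a in parser._actions if a.dest == "model_family")
assert "omnigen" in family_action.choices
print(sorted(family_action.choices))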