wip: omnigen initial training support attempt #1117

Draft
wants to merge 6 commits into main
Conversation

bghira (Owner) commented Nov 2, 2024

  • batch size 1 only, because the padding is currently incorrect
  • high loss, also caused by the incorrect padding
  • validations work
  • LyCORIS and PEFT LoRAs work
  • no ControlNet-style training yet

Comment on lines 2075 to 2096
# Assemble the OmniGen transformer inputs; optional tensors stay None when the
# batch did not provide them (plain text-to-image without conditioning images).
inputs = {
    "x": noisy_latents,
    "timestep": timesteps,
    "input_ids": (
        batch.get("input_ids").to(self.accelerator.device)
        if batch.get("input_ids") is not None
        else None
    ),
    "input_img_latents": (
        batch.get("input_img_latents").to(self.accelerator.device)
        if batch.get("input_img_latents") is not None
        else None
    ),
    "input_image_sizes": batch.get("input_image_sizes"),
    "attention_mask": batch.get("encoder_attention_mask").to(
        self.accelerator.device
    ),
    "position_ids": batch.get("position_ids").to(
        self.accelerator.device
    ),
}
model_pred = self.transformer(**inputs)[0]
bghira (Owner Author):
@staoxiao this is where I'm building and running the inputs for OmniGen.

helpers/training/trainer.py (outdated; resolved)
Comment on lines 563 to 619
elif StateTracker.get_model_family() == "omnigen":
    # instruction, output_image = example['instruction'], example['input_images'], example['output_image']
    omnigen_processed_embeddings = compute_prompt_embeddings(
        captions, text_embed_cache
    )
    from OmniGen.processor import OmniGenCollator

    attn_mask = [e.get("attention_mask") for e in omnigen_processed_embeddings]
    attn_mask_len = len(attn_mask[0][0])
    attn_mask = torch.stack(attn_mask, dim=0)

    # we can use OmniGenCollator.create_position to make positional ids
    num_tokens_for_output_images = []
    for img_size in [
        [
            latent_batch.shape[3] * 8,
            latent_batch.shape[2] * 8,
        ]
        * len(latent_batch)
    ]:
        # each 16x16 patch of the decoded output image becomes one token
        num_img_tokens = img_size[0] * img_size[1] // 16 // 16
        num_text_tokens = attn_mask_len
        total_num_tokens = num_img_tokens - num_text_tokens
        num_tokens_for_output_images.append(total_num_tokens)
    position_ids = OmniGenCollator.create_position(
        attn_mask, num_tokens_for_output_images
    )
    # pad attn_mask to match the position_ids, eg. mask [1, 1, 1, 57] -> [1, 1, 1, 4097]
    attn_mask = torch.cat(
        [
            attn_mask,
            torch.zeros(
                (
                    attn_mask.shape[0],
                    attn_mask.shape[1],
                    num_tokens_for_output_images[0] + 1,
                )
            ),
        ],
        dim=-1,
    )

    # TODO: support "input images" for OmniGen which behave as conditioning images, eg. ControlNet Canny, Depth, etc.
    # conditioning_pixel_values = torch.stack([e.get('input_pixel_values') for e in omnigen_processed_embeddings], dim=0)
    # input_image_sizes = [e.get('input_image_size') for e in omnigen_processed_embeddings]
    # extra_batch_inputs['conditioning_pixel_values'] = conditioning_pixel_values
    # extra_batch_inputs['input_image_sizes'] = input_image_sizes
    # input_ids = [e.get('input_ids') for e in omnigen_processed_embeddings]
    # input_ids = torch.stack(input_ids, dim=0)
    # TODO: Support instruction/conditioning image dropout for OmniGen.
    # if random.random() < StateTracker.get_args().caption_dropout_probability:
    #     instruction = '<cfg>'
    #     latent_batch = None
    # the input_ids stacking above is disabled for now, so pass None explicitly
    # (matches the `input_ids: None` seen in the shape dump further down)
    input_ids = None
    padding_images = [e.get("padding_image") for e in omnigen_processed_embeddings]
    extra_batch_inputs["position_ids"] = position_ids
    extra_batch_inputs["padding_images"] = padding_images
    extra_batch_inputs["input_ids"] = input_ids
bghira (Owner Author):
@staoxiao this portion of the prep work is where I ran into the most confusion; I just brute-forced shapes until things worked.
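
For concreteness, the token accounting this block ends up doing in the batch-size-1, 1024x1024 case (a rough walk-through in plain Python, not code from the PR; the 57 comes from the inline mask example in the snippet):

# latent_batch is [1, 4, 128, 128], so the decoded output image is 1024 x 1024
num_img_tokens = 1024 * 1024 // 16 // 16                             # 4096 patch tokens
num_text_tokens = 57                                                 # attn_mask_len in the example above
num_tokens_for_output_images = [num_img_tokens - num_text_tokens]    # [4039]
padded_mask_len = num_text_tokens + num_tokens_for_output_images[0] + 1
# padded_mask_len == 4097, i.e. [1, 1, 1, 57] -> [1, 1, 1, 4097]; note the text
# length cancels out, so the padded mask is always num_img_tokens + 1 columns wide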

staoxiao:

To be honest, our implementation of the attention mask is a little complex, since it integrates both unidirectional and bidirectional attention. I suggest using our function directly: https://github.com/VectorSpaceLab/OmniGen/blob/main/OmniGen/processor.py#L241C34-L241C62
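
Roughly, that would replace the manual zero-padding above with something like the following sketch (hedged: the mask helper's name and exact signature are assumptions here and should be confirmed against OmniGen/processor.py at the link above; the call style simply mirrors the create_position usage already in this PR):

from OmniGen.processor import OmniGenCollator

# positional ids for the text tokens followed by the output-image tokens
position_ids = OmniGenCollator.create_position(attn_mask, num_tokens_for_output_images)
# let OmniGen build the combined unidirectional/bidirectional mask itself,
# instead of concatenating zero columns onto the text-only mask
attention_mask = OmniGenCollator.create_mask(attn_mask, num_tokens_for_output_images)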

bghira (Owner Author):
OK. It's just really not commented in there, and there are no docstrings to help make use of it, so we have to reverse-engineer the code by tracing through it a lot.

-> x: torch.Size([1, 4, 128, 128])
-> timestep: torch.Size([1])
-> input_ids: None
-> input_img_latents: None
-> attention_mask: torch.Size([1, 4120, 4120])
-> position_ids: torch.Size([1, 4120])
The size of tensor a (4097) must match the size of tensor b (4120) at non-singleton dimension 2

These are the current input shapes we get back from the OmniGenProcessor.

Not sure why such weird shapes are coming back.
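
A back-of-the-envelope reconstruction of the two lengths (hedged guesswork from the numbers above, not anything from the OmniGen docs; the role of the extra +1 token is assumed):

# x is [1, 4, 128, 128], so the output image is 1024 x 1024
num_img_tokens = 1024 * 1024 // 16 // 16     # 4096
num_text_tokens = 23                         # assumed: 4120 - 4096 - 1 = 23

# what the OmniGenProcessor hands back: text + image tokens + 1 extra token
processor_len = num_text_tokens + num_img_tokens + 1                    # 4120

# what the hand-rolled padding in the prep block produces: the text-length
# mask plus (num_img_tokens - num_text_tokens) + 1 zero columns, which is
# always num_img_tokens + 1 regardless of text length
padded_len = num_text_tokens + (num_img_tokens - num_text_tokens) + 1   # 4097

print(processor_len - padded_len)  # 23, exactly the assumed text-token count

If that reading is right, the mismatch comes from the num_img_tokens - num_text_tokens subtraction: the padded mask ends up short by exactly the number of text tokens.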

bghira (Owner Author):
@staoxiao any ideas? We're currently stalled out on this because we can't get the loss to behave. Is it a sequence-length issue with the input embeddings?
