wip: omnigen initial training support attempt #1117
base: main
Conversation
bghira commented Nov 2, 2024
- batch size 1 only due to padding being incorrect
- high loss due to padding being incorrect
- but validations work
- LyCORIS, PEFT LoRAs work
- no controlnet style training yet
bghira force-pushed the branch from c9883cb to e1122fd
helpers/training/trainer.py
Outdated
inputs = {
    "x": noisy_latents,
    "timestep": timesteps,
    # optional OmniGen inputs stay None when the batch does not provide them
    "input_ids": (
        batch.get("input_ids").to(self.accelerator.device)
        if batch.get("input_ids") is not None
        else None
    ),
    "input_img_latents": (
        batch.get("input_img_latents").to(self.accelerator.device)
        if batch.get("input_img_latents") is not None
        else None
    ),
    "input_image_sizes": batch.get("input_image_sizes"),
    "attention_mask": batch.get("encoder_attention_mask").to(
        self.accelerator.device
    ),
    "position_ids": batch.get("position_ids").to(
        self.accelerator.device
    ),
}
model_pred = self.transformer(**inputs)[0]
bghira commented
@staoxiao this is where i'm running the inputs for OmniGen
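A quick shape dump right before the forward call makes length mismatches much easier to spot; a minimal sketch of that kind of logging (plain print here, though the trainer presumably has its own logger), which produces output like the -> x: ... lines quoted further down:

```python
# Print each input's shape (or the raw value for non-tensors such as None)
# before calling the transformer, so attention_mask / position_ids length
# mismatches are visible immediately.
for name, value in inputs.items():
    print(f"-> {name}: {value.shape if torch.is_tensor(value) else value}")
model_pred = self.transformer(**inputs)[0]
```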
helpers/training/collate.py
Outdated
elif StateTracker.get_model_family() == "omnigen":
    # instruction, output_image = example['instruction'], example['input_images'], example['output_image']
    omnigen_processed_embeddings = compute_prompt_embeddings(
        captions, text_embed_cache
    )
    from OmniGen.processor import OmniGenCollator

    attn_mask = [e.get("attention_mask") for e in omnigen_processed_embeddings]
    attn_mask_len = len(attn_mask[0][0])
    attn_mask = torch.stack(attn_mask, dim=0)

    # we can use OmniGenCollator.create_position to make positional ids
    num_tokens_for_output_images = []
    # NOTE: the outer list below has a single element, so only one token count
    # is produced per batch (batch size 1 for now).
    for img_size in [
        [
            latent_batch.shape[3] * 8,
            latent_batch.shape[2] * 8,
        ]
        * len(latent_batch)
    ]:
        num_img_tokens = img_size[0] * img_size[1] // 16 // 16
        num_text_tokens = attn_mask_len
        total_num_tokens = num_img_tokens - num_text_tokens
        num_tokens_for_output_images.append(total_num_tokens)
    position_ids = OmniGenCollator.create_position(
        attn_mask, num_tokens_for_output_images
    )
    # pad attn_mask to match the position_ids, e.g. mask [1, 1, 1, 57] -> [1, 1, 1, 4097]
    attn_mask = torch.cat(
        [
            attn_mask,
            torch.zeros(
                (
                    attn_mask.shape[0],
                    attn_mask.shape[1],
                    num_tokens_for_output_images[0] + 1,
                )
            ),
        ],
        dim=-1,
    )

    # TODO: support "input images" for OmniGen which behave as conditioning images, e.g. ControlNet Canny, Depth, etc.
    # conditioning_pixel_values = torch.stack([e.get('input_pixel_values') for e in omnigen_processed_embeddings], dim=0)
    # input_image_sizes = [e.get('input_image_size') for e in omnigen_processed_embeddings]
    # extra_batch_inputs['conditioning_pixel_values'] = conditioning_pixel_values
    # extra_batch_inputs['input_image_sizes'] = input_image_sizes
    # input_ids = [e.get('input_ids') for e in omnigen_processed_embeddings]
    # input_ids = torch.stack(input_ids, dim=0)
    # TODO: Support instruction/conditioning image dropout for OmniGen.
    # if random.random() < StateTracker.get_args().caption_dropout_probability:
    #     instruction = '<cfg>'
    #     latent_batch = None
    padding_images = [e.get("padding_image") for e in omnigen_processed_embeddings]
    extra_batch_inputs["position_ids"] = position_ids
    extra_batch_inputs["padding_images"] = padding_images
    extra_batch_inputs["input_ids"] = input_ids
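A cheap guard at the end of this branch would surface any length disagreement at collate time instead of as a broadcast error inside the transformer; a minimal sketch using only names already defined above:

```python
# Sanity check: the padded attention mask and the position ids must cover the
# same number of tokens, otherwise the model's additive-mask broadcast fails
# later (e.g. the 4097 vs 4120 mismatch discussed further down).
assert attn_mask.shape[-1] == position_ids.shape[-1], (
    f"attention mask covers {attn_mask.shape[-1]} tokens but "
    f"position_ids covers {position_ids.shape[-1]}"
)
```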
bghira commented
@staoxiao this portion of prep work is where I ran into the most confusion; I just brute-forced shapes until things worked.
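One way to take the brute-forcing out of the padding would be to let position_ids define the target length and pad the mask to match, so the two tensors cannot drift apart. A rough sketch reusing the names from the diff above (whether this matches the intended semantics of create_position still needs checking against OmniGen's processor):

```python
# Derive the pad target from position_ids instead of computing it independently.
position_ids = OmniGenCollator.create_position(
    attn_mask, num_tokens_for_output_images
)
target_len = position_ids.shape[-1]
pad = target_len - attn_mask.shape[-1]
if pad > 0:
    pad_block = torch.zeros(*attn_mask.shape[:-1], pad, dtype=attn_mask.dtype)
    attn_mask = torch.cat([attn_mask, pad_block], dim=-1)
```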
staoxiao commented
To be honest, our implementation of atten_mask is a little complex, integrating both unidirectional and bidirectional attention mechanisms. I suggest using our function directly: https://github.com/VectorSpaceLab/OmniGen/blob/main/OmniGen/processor.py#L241C34-L241C62
bghira commented
ok. it's just really not commented in there and there are no docstrings to help make use of it. we have to reverse-engineer the code by tracing through it a lot.
-> x: torch.Size([1, 4, 128, 128])
-> timestep: torch.Size([1])
-> input_ids: None
-> input_img_latents: None
-> attention_mask: torch.Size([1, 4120, 4120])
-> position_ids: torch.Size([1, 4120])
The size of tensor a (4097) must match the size of tensor b (4120) at non-singleton dimension 2
these are the current input shapes we get from the OmniGenProcessor
Not sure why such weird shapes are coming back.
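For what it's worth, the 23-token gap between 4097 and 4120 lines up with the text length being dropped from the pad calculation in the collate. A back-of-the-envelope check (the breakdown of 4120 into text + image + 1 tokens is an assumption inferred from the shapes above, not confirmed against OmniGen):

```python
# Assumed for this batch: 1024x1024 output image (128x128 latents * 8) and 23 text tokens.
num_text_tokens = 23
num_img_tokens = (128 * 8) * (128 * 8) // 16 // 16        # 4096 image patch tokens
processor_seq_len = num_text_tokens + num_img_tokens + 1  # 4120, matches position_ids above
# The collate pads the mask to num_img_tokens + 1 = 4097, because
# total_num_tokens = num_img_tokens - num_text_tokens cancels out the text length.
collate_mask_len = num_img_tokens + 1
print(processor_seq_len, collate_mask_len, processor_seq_len - collate_mask_len)  # 4120 4097 23
```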
bghira commented
@staoxiao any ideas? We're currently stalled on this because we can't get the loss to behave. Is it a sequence-length issue with the input embeddings?
…g model designation