Skip to content

Commit

Permalink
[Fix] Fix num_images_per_prompt in controlnet (#1936)
Browse files Browse the repository at this point in the history
* fix num_images_per_prompt in controlnet + revise ut

* remove one unit test due to memory limit

* remove one unit test due to memory limit
  • Loading branch information
LeoXing1996 authored Jul 13, 2023
1 parent d43496c commit 9d37453
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 10 deletions.
7 changes: 5 additions & 2 deletions mmagic/models/editors/controlnet/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor,
Args:
image (Tuple[Image.Image, List[Image.Image], Tensor, List[Tensor]]): # noqa
The input image for control.
batch_size (int): The batch size of the control. The control will
batch_size (int): The number of prompts. The control will
be repeated for `batch_size` times.
num_images_per_prompt (int): The number of images to generate
    for each prompt.
Expand Down Expand Up @@ -364,8 +364,11 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor,
image_batch_size = image.shape[0]

if image_batch_size == 1:
repeat_by = batch_size
repeat_by = batch_size * num_images_per_prompt
else:
assert image_batch_size == batch_size, (
'The number of Control condition must be 1 or equal to the '
'number of prompt.')
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt

Expand Down
55 changes: 47 additions & 8 deletions tests/test_models/test_editors/test_controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,24 +92,49 @@ def test_init_weights(self):

def test_infer(self):
control_sd = self.control_sd
control = torch.ones([1, 3, 64, 64])

def mock_encode_prompt(*args, **kwargs):
return torch.randn(2, 5, 16) # 2 for cfg
def mock_encode_prompt(prompt, do_classifier_free_guidance,
num_images_per_prompt, *args, **kwargs):
batch_size = len(prompt) if isinstance(prompt, list) else 1
batch_size *= num_images_per_prompt
if do_classifier_free_guidance:
batch_size *= 2
return torch.randn(batch_size, 5, 16) # 2 for cfg

encode_prompt = control_sd._encode_prompt
control_sd._encode_prompt = mock_encode_prompt

# one prompt, one control, repeat 1 time
self._test_infer(control_sd, 1, 1, 1, 1)

# two prompt, two control, repeat 1 time
# NOTE: skip this due to memory limit
# self._test_infer(control_sd, 2, 2, 1, 2)

# one prompt, one control, repeat 2 times
# NOTE: skip this due to memory limit
# self._test_infer(control_sd, 1, 1, 2, 2)

# two prompt, two control, repeat 2 times
# NOTE: skip this due to memory limit
# self._test_infer(control_sd, 2, 2, 2, 4)

control_sd._encode_prompt = encode_prompt

def _test_infer(self, control_sd, num_prompt, num_control, num_repeat,
tar_shape):
prompt = ''
control = torch.ones([1, 3, 64, 64])

result = control_sd.infer(
'an insect robot preparing a delicious meal',
control=control,
[prompt] * num_prompt,
control=[control] * num_control,
height=64,
width=64,
num_images_per_prompt=num_repeat,
num_inference_steps=1,
return_type='numpy')
assert result['samples'].shape == (1, 3, 64, 64)

control_sd._encode_prompt = encode_prompt
assert result['samples'].shape == (tar_shape, 3, 64, 64)

def test_val_step(self):
control_sd = self.control_sd
Expand All @@ -126,13 +151,20 @@ def test_val_step(self):
def mock_encode_prompt(*args, **kwargs):
return torch.randn(2, 5, 16) # 2 for cfg

def mock_infer(*args, **kwargs):
return dict(samples=torch.randn(2, 3, 64, 64))

encode_prompt = control_sd._encode_prompt
control_sd._encode_prompt = mock_encode_prompt

infer = control_sd.infer
control_sd.infer = mock_infer

# control_sd.text_encoder = mock_text_encoder()
output = control_sd.val_step(data)
assert len(output) == 1
control_sd._encode_prompt = encode_prompt
control_sd.infer = infer

def test_test_step(self):
control_sd = self.control_sd
Expand All @@ -149,13 +181,20 @@ def test_test_step(self):
def mock_encode_prompt(*args, **kwargs):
return torch.randn(2, 5, 16) # 2 for cfg

def mock_infer(*args, **kwargs):
return dict(samples=torch.randn(2, 3, 64, 64))

encode_prompt = control_sd._encode_prompt
control_sd._encode_prompt = mock_encode_prompt

infer = control_sd.infer
control_sd.infer = mock_infer

# control_sd.text_encoder = mock_text_encoder()
output = control_sd.test_step(data)
assert len(output) == 1
control_sd._encode_prompt = encode_prompt
control_sd.infer = infer

def test_train_step(self):
control_sd = self.control_sd
Expand Down

0 comments on commit 9d37453

Please sign in to comment.