diff --git a/mmagic/models/editors/controlnet/controlnet.py b/mmagic/models/editors/controlnet/controlnet.py
index dfce2c44ed..be23afcc07 100644
--- a/mmagic/models/editors/controlnet/controlnet.py
+++ b/mmagic/models/editors/controlnet/controlnet.py
@@ -334,7 +334,7 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor,
         Args:
             image (Tuple[Image.Image, List[Image.Image], Tensor, List[Tensor]]):  # noqa
                 The input image for control.
-            batch_size (int): The batch size of the control. The control will
+            batch_size (int): The number of the prompt. The control will
                 be repeated for `batch_size` times.
             num_images_per_prompt (int): The number images generate for one
                 prompt.
@@ -364,8 +364,11 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor,
         image_batch_size = image.shape[0]
 
         if image_batch_size == 1:
-            repeat_by = batch_size
+            repeat_by = batch_size * num_images_per_prompt
         else:
+            assert image_batch_size == batch_size, (
+                'The number of Control condition must be 1 or equal to the '
+                'number of prompt.')
             # image batch size is the same as prompt batch size
             repeat_by = num_images_per_prompt
 
diff --git a/tests/test_models/test_editors/test_controlnet/test_controlnet.py b/tests/test_models/test_editors/test_controlnet/test_controlnet.py
index 0263a7528e..d29bc1ebcf 100644
--- a/tests/test_models/test_editors/test_controlnet/test_controlnet.py
+++ b/tests/test_models/test_editors/test_controlnet/test_controlnet.py
@@ -92,24 +92,49 @@ def test_init_weights(self):
 
     def test_infer(self):
         control_sd = self.control_sd
-        control = torch.ones([1, 3, 64, 64])
 
-        def mock_encode_prompt(*args, **kwargs):
-            return torch.randn(2, 5, 16)  # 2 for cfg
+        def mock_encode_prompt(prompt, do_classifier_free_guidance,
+                               num_images_per_prompt, *args, **kwargs):
+            batch_size = len(prompt) if isinstance(prompt, list) else 1
+            batch_size *= num_images_per_prompt
+            if do_classifier_free_guidance:
+                batch_size *= 2
+            return torch.randn(batch_size, 5, 16)  # 2 for cfg
 
         encode_prompt = control_sd._encode_prompt
         control_sd._encode_prompt = mock_encode_prompt
 
+        # one prompt, one control, repeat 1 time
+        self._test_infer(control_sd, 1, 1, 1, 1)
+
+        # two prompt, two control, repeat 1 time
+        # NOTE: skip this due to memory limit
+        # self._test_infer(control_sd, 2, 2, 1, 2)
+
+        # one prompt, one control, repeat 2 times
+        # NOTE: skip this due to memory limit
+        # self._test_infer(control_sd, 1, 1, 2, 2)
+
+        # two prompt, two control, repeat 2 times
+        # NOTE: skip this due to memory limit
+        # self._test_infer(control_sd, 2, 2, 2, 4)
+
+        control_sd._encode_prompt = encode_prompt
+
+    def _test_infer(self, control_sd, num_prompt, num_control, num_repeat,
+                    tar_shape):
+        prompt = ''
+        control = torch.ones([1, 3, 64, 64])
+
         result = control_sd.infer(
-            'an insect robot preparing a delicious meal',
-            control=control,
+            [prompt] * num_prompt,
+            control=[control] * num_control,
             height=64,
             width=64,
+            num_images_per_prompt=num_repeat,
             num_inference_steps=1,
             return_type='numpy')
-        assert result['samples'].shape == (1, 3, 64, 64)
-
-        control_sd._encode_prompt = encode_prompt
+        assert result['samples'].shape == (tar_shape, 3, 64, 64)
 
     def test_val_step(self):
         control_sd = self.control_sd
@@ -126,13 +151,20 @@ def test_val_step(self):
         def mock_encode_prompt(*args, **kwargs):
             return torch.randn(2, 5, 16)  # 2 for cfg
 
+        def mock_infer(*args, **kwargs):
+            return dict(samples=torch.randn(2, 3, 64, 64))
+
         encode_prompt = control_sd._encode_prompt
         control_sd._encode_prompt = mock_encode_prompt
 
+        infer = control_sd.infer
+        control_sd.infer = mock_infer
+
         # control_sd.text_encoder = mock_text_encoder()
         output = control_sd.val_step(data)
         assert len(output) == 1
         control_sd._encode_prompt = encode_prompt
+        control_sd.infer = infer
 
     def test_test_step(self):
         control_sd = self.control_sd
@@ -149,13 +181,20 @@ def test_test_step(self):
         def mock_encode_prompt(*args, **kwargs):
             return torch.randn(2, 5, 16)  # 2 for cfg
 
+        def mock_infer(*args, **kwargs):
+            return dict(samples=torch.randn(2, 3, 64, 64))
+
         encode_prompt = control_sd._encode_prompt
         control_sd._encode_prompt = mock_encode_prompt
 
+        infer = control_sd.infer
+        control_sd.infer = mock_infer
+
         # control_sd.text_encoder = mock_text_encoder()
         output = control_sd.test_step(data)
         assert len(output) == 1
         control_sd._encode_prompt = encode_prompt
+        control_sd.infer = infer
 
     def test_train_step(self):
         control_sd = self.control_sd
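
Below is a minimal, self-contained sketch of the batching rule this patch enforces in `prepare_control`. The helper `repeat_control` is hypothetical (not part of the mmagic API) and only illustrates the behaviour exercised by the new tests: a single control image is repeated `batch_size * num_images_per_prompt` times, while a per-prompt batch of controls must match the number of prompts and is repeated only `num_images_per_prompt` times.

import torch
from torch import Tensor


def repeat_control(control: Tensor, batch_size: int,
                   num_images_per_prompt: int) -> Tensor:
    """Repeat an (N, C, H, W) control tensor to match the sample batch.

    Hypothetical helper mirroring the repeat logic added to
    ``prepare_control``; not the mmagic implementation itself.
    """
    image_batch_size = control.shape[0]
    if image_batch_size == 1:
        # a single control image is shared by every prompt and every sample
        repeat_by = batch_size * num_images_per_prompt
    else:
        # one control per prompt; repeat only for the samples of each prompt
        assert image_batch_size == batch_size, (
            'The number of control images must be 1 or equal to the '
            'number of prompts.')
        repeat_by = num_images_per_prompt
    return control.repeat_interleave(repeat_by, dim=0)


# 1 control, 2 prompts, 2 images per prompt -> 4 control entries
assert repeat_control(torch.ones(1, 3, 64, 64), 2, 2).shape[0] == 4
# 2 controls, 2 prompts, 2 images per prompt -> 4 control entries
assert repeat_control(torch.ones(2, 3, 64, 64), 2, 2).shape[0] == 4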