Skip to content

Commit

Permalink
[fbsync] draw_keypoints() float support (#8276)
Browse files Browse the repository at this point in the history
Reviewed By: vmoens

Differential Revision: D55062794

fbshipit-source-id: 1a9484e4959fef604153857cc7d4a6d7262cbea9

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
3 people authored and facebook-github-bot committed Mar 20, 2024
1 parent 9613867 commit d2f3a61
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
16 changes: 16 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,22 @@ def test_draw_keypoints_visibility_default():
assert_equal(result, expected)


def test_draw_keypoints_dtypes():
image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
image_float = to_dtype(image_uint8, torch.float, scale=True)

out_uint8 = utils.draw_keypoints(image_uint8, keypoints)
out_float = utils.draw_keypoints(image_float, keypoints)

assert out_uint8.dtype == torch.uint8
assert out_uint8 is not image_uint8

assert out_float.is_floating_point()
assert out_float is not image_float

torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)


def test_draw_keypoints_errors():
h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
Expand Down
21 changes: 15 additions & 6 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,13 @@ def draw_keypoints(

"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
The image values should be uint8 in [0, 255] or float in [0, 1].
Keypoints can be drawn for multiple instances at a time.
This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
in the format [x, y].
connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
Expand All @@ -363,16 +363,16 @@ def draw_keypoints(
For more details, see :ref:`draw_keypoints_with_visibility`.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
img (Tensor[C, H, W]): Image Tensor with keypoints drawn.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_keypoints)
# validate image
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
Expand All @@ -397,6 +397,12 @@ def draw_keypoints(
f"Got {visibility.shape = } and {keypoints.shape = }"
)

original_dtype = image.dtype
if original_dtype.is_floating_point:
from torchvision.transforms.v2.functional import to_dtype # noqa

image = to_dtype(image, dtype=torch.uint8, scale=True)

ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
Expand Down Expand Up @@ -428,7 +434,10 @@ def draw_keypoints(
width=width,
)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
if original_dtype.is_floating_point:
out = to_dtype(out, dtype=original_dtype, scale=True)
return out


# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
Expand Down

0 comments on commit d2f3a61

Please sign in to comment.