port sample input smoke test
pmeier committed Sep 14, 2023
1 parent 775e5ec commit 7b1ef58
Showing 3 changed files with 176 additions and 84 deletions.
4 changes: 2 additions & 2 deletions test/common_utils.py
@@ -406,6 +406,7 @@ def make_bounding_boxes(
     canvas_size=DEFAULT_SIZE,
     *,
     format=tv_tensors.BoundingBoxFormat.XYXY,
+    num_objects=1,
     dtype=None,
     device="cpu",
 ):
@@ -419,8 +420,7 @@ def sample_position(values, max_value):
 
     dtype = dtype or torch.float32
 
-    num_objects = 1
-    h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
+    h, w = [torch.randint(1, s, (num_objects,)) for s in canvas_size]
     y = sample_position(h, canvas_size[0])
     x = sample_position(w, canvas_size[1])
 
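Taken together, the two hunks above turn the previously hard-coded num_objects = 1 into a keyword argument of make_bounding_boxes, so a single call can produce a batch of boxes. A minimal usage sketch, assuming the helper is imported from test/common_utils.py and returns a tv_tensors.BoundingBoxes with one row per object (the exact return layout is not shown in this diff):

from torchvision import tv_tensors

from common_utils import make_bounding_boxes  # test helper patched above

# One call now yields several boxes instead of exactly one.
boxes = make_bounding_boxes(
    canvas_size=(32, 32),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    num_objects=3,
)
assert isinstance(boxes, tv_tensors.BoundingBoxes)
assert boxes.shape == (3, 4)  # assumed layout: one XYXY row per object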
60 changes: 1 addition & 59 deletions test/test_transforms_v2.py
@@ -17,7 +17,7 @@
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
+from torchvision.transforms.v2._utils import is_pure_tensor, query_chw
 from transforms_v2_legacy_utils import (
     make_bounding_boxes,
     make_detection_mask,
@@ -62,22 +62,6 @@ def parametrize(transforms_with_inputs):
 )
 
 
-def auto_augment_adapter(transform, input, device):
-    adapted_input = {}
-    image_or_video_found = False
-    for key, value in input.items():
-        if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
-            # AA transforms don't support bounding boxes or masks
-            continue
-        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
-            if image_or_video_found:
-                # AA transforms only support a single image or video
-                continue
-            image_or_video_found = True
-        adapted_input[key] = value
-    return adapted_input
-
-
 def linear_transformation_adapter(transform, input, device):
     flat_inputs = list(input.values())
     c, h, w = query_chw(
@@ -93,58 +77,19 @@ def linear_transformation_adapter(transform, input, device):
     return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
 
 
-def normalize_adapter(transform, input, device):
-    adapted_input = {}
-    for key, value in input.items():
-        if isinstance(value, PIL.Image.Image):
-            # normalize doesn't support PIL images
-            continue
-        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
-            # normalize doesn't support integer images
-            value = F.to_dtype(value, torch.float32, scale=True)
-        adapted_input[key] = value
-    return adapted_input
-
-
 class TestSmoke:
     @pytest.mark.parametrize(
         ("transform", "adapter"),
         [
             (transforms.RandomErasing(p=1.0), None),
-            (transforms.AugMix(), auto_augment_adapter),
-            (transforms.AutoAugment(), auto_augment_adapter),
-            (transforms.RandAugment(), auto_augment_adapter),
-            (transforms.TrivialAugmentWide(), auto_augment_adapter),
             (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
             (transforms.Grayscale(), None),
             (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
             (transforms.RandomAutocontrast(p=1.0), None),
             (transforms.RandomEqualize(p=1.0), None),
             (transforms.RandomGrayscale(p=1.0), None),
             (transforms.RandomInvert(p=1.0), None),
             (transforms.RandomChannelPermutation(), None),
             (transforms.RandomPhotometricDistort(p=1.0), None),
             (transforms.RandomPosterize(bits=4, p=1.0), None),
             (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
             (transforms.CenterCrop([16, 16]), None),
             (transforms.ElasticTransform(sigma=1.0), None),
             (transforms.Pad(4), None),
             (transforms.RandomAffine(degrees=30.0), None),
             (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
             (transforms.RandomHorizontalFlip(p=1.0), None),
             (transforms.RandomPerspective(p=1.0), None),
             (transforms.RandomResize(min_size=10, max_size=20, antialias=True), None),
             (transforms.RandomResizedCrop([16, 16], antialias=True), None),
             (transforms.RandomRotation(degrees=30), None),
             (transforms.RandomShortestSize(min_size=10, antialias=True), None),
             (transforms.RandomVerticalFlip(p=1.0), None),
             (transforms.RandomZoomOut(p=1.0), None),
             (transforms.Resize([16, 16], antialias=True), None),
             (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
             (transforms.ClampBoundingBoxes(), None),
             (transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
             (transforms.ConvertImageDtype(), None),
             (transforms.GaussianBlur(kernel_size=3), None),
             (
                 transforms.LinearTransformation(
                     # These are just dummy values that will be filled by the adapter. We can't define them upfront,
@@ -154,9 +99,6 @@ class TestSmoke:
                 ),
                 linear_transformation_adapter,
             ),
-            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
-            (transforms.ToDtype(torch.float64), None),
-            (transforms.UniformTemporalSubsample(num_samples=2), None),
         ],
         ids=lambda transform: type(transform).__name__,
     )
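For context, every adapter in the parametrization above follows the same protocol: it receives the transform instance, the dict of sample inputs, and the device, and returns a dict containing only inputs the transform can handle, converted where necessary. The body of linear_transformation_adapter is collapsed in this view; the following is a rough reconstruction under the assumption that it fills in the dummy transformation_matrix and mean_vector once the sample's shape is known (the argument to query_chw is simplified here):

import PIL.Image
import torch
from torchvision.transforms.v2._utils import query_chw


def linear_transformation_adapter(transform, input, device):
    # Determine the (C, H, W) shared by the image-like sample inputs.
    flat_inputs = list(input.values())
    c, h, w = query_chw(flat_inputs)
    num_elements = c * h * w
    # Assumption: replace the dummy parameters LinearTransformation was
    # constructed with, now that the required sizes are known.
    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
    transform.mean_vector = torch.randn((num_elements,), device=device)
    # LinearTransformation does not support PIL images, so drop them.
    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}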
