Skip to content

Commit

Permalink
Port remaining transforms tests (#7954)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Sep 26, 2023
1 parent 997384c commit 1a9ff0d
Show file tree
Hide file tree
Showing 9 changed files with 730 additions and 1,947 deletions.
104 changes: 0 additions & 104 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,57 +272,6 @@ def test_common(self, transform, adapter, container_type, image_or_video, de_ser
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)

# Smoke test: every AutoAugment-family transform must accept each supported
# input kind (tv_tensor images, plain tensors, PIL images, videos) in both
# grayscale and RGB uint8 form, with and without a leading batch dim,
# without raising. Output values are not asserted.
@parametrize(
    [
        (
            transform,
            itertools.chain.from_iterable(
                fn(
                    color_spaces=[
                        "GRAY",
                        "RGB",
                    ],
                    dtypes=[torch.uint8],
                    extra_dims=[(), (4,)],
                    # only the video maker takes num_frames — skip it otherwise
                    **(dict(num_frames=[3]) if fn is make_videos else dict()),
                )
                for fn in [
                    make_images,
                    make_vanilla_tensor_images,
                    make_pil_images,
                    make_videos,
                ]
            ),
        )
        for transform in (
            transforms.RandAugment(),
            transforms.TrivialAugmentWide(),
            transforms.AutoAugment(),
            transforms.AugMix(),
        )
    ]
)
def test_auto_augment(self, transform, input):
    # Passing is defined as "does not raise".
    transform(input)

# Smoke test: Normalize must accept float32 RGB inputs of every supported
# kind (tv_tensor images, plain tensors, videos) without raising.
@parametrize(
    [
        (
            transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
            itertools.chain(
                make_images(color_spaces=["RGB"], dtypes=[torch.float32]),
                make_vanilla_tensor_images(color_spaces=["RGB"], dtypes=[torch.float32]),
                make_videos(color_spaces=["RGB"], dtypes=[torch.float32]),
            ),
        ),
    ]
)
def test_normalize(self, transform, input):
    # Passing is defined as "does not raise".
    transform(input)


@pytest.mark.parametrize(
"flat_inputs",
Expand Down Expand Up @@ -385,40 +334,6 @@ def was_applied(output, inpt):
assert transform.was_applied(output, input)


class TestElasticTransform:
    """Tests for transforms.ElasticTransform: constructor validation and the
    shape/range of the sampled displacement field."""

    def test_assertions(self):
        # Each invalid constructor call must raise the documented exception
        # with the documented message.
        bad_calls = [
            (TypeError, "alpha should be a number or a sequence of numbers", ({},), {}),
            (ValueError, "alpha is a sequence its length should be 1 or 2", ([1.0, 2.0, 3.0],), {}),
            (TypeError, "sigma should be a number or a sequence of numbers", (1.0, {}), {}),
            (ValueError, "sigma is a sequence its length should be 1 or 2", (1.0, [1.0, 2.0, 3.0]), {}),
            (TypeError, "Got inappropriate fill arg", (1.0, 2.0), {"fill": "abc"}),
        ]
        for exc_type, pattern, args, kwargs in bad_calls:
            with pytest.raises(exc_type, match=pattern):
                transforms.ElasticTransform(*args, **kwargs)

    def test__get_params(self):
        # _get_params must return a per-pixel displacement of shape
        # (1, H, W, 2) whose entries are bounded by alpha normalized by the
        # corresponding image extent.
        alpha, sigma = 2.0, 3.0
        elastic = transforms.ElasticTransform(alpha, sigma)

        height, width = image_size = (24, 32)
        sample = transforms.ElasticTransform(alpha, sigma)._get_params([make_image(image_size)]) if False else elastic._get_params([make_image(image_size)])

        displacement = sample["displacement"]
        assert displacement.shape == (1, height, width, 2)
        dx = displacement[0, ..., 0]
        dy = displacement[0, ..., 1]
        # |dx| <= alpha / W and |dy| <= alpha / H, i.e. the same two-sided
        # bound the original spelled out explicitly.
        assert (dx.abs() <= alpha / width).all()
        assert (dy.abs() <= alpha / height).all()


class TestTransform:
@pytest.mark.parametrize(
"inpt_type",
Expand Down Expand Up @@ -705,25 +620,6 @@ def test__get_params(self):
assert min_size <= size < max_size


class TestUniformTemporalSubsample:
    """UniformTemporalSubsample must shrink the temporal dimension to the
    requested frame count while leaving the container type and dtype alone."""

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.zeros(10, 3, 8, 8),
            torch.zeros(1, 10, 3, 8, 8),
            tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
        ],
    )
    def test__transform(self, inpt):
        target_frames = 5
        subsample = transforms.UniformTemporalSubsample(target_frames)

        result = subsample(inpt)

        # Only the frame dimension (index -4) may change; type and dtype
        # must round-trip unchanged.
        assert type(result) is type(inpt)
        assert result.shape[-4] == target_frames
        assert result.dtype == inpt.dtype


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
Expand Down
130 changes: 0 additions & 130 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,34 +72,6 @@ def __init__(
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.Normalize,
legacy_transforms.Normalize,
[
ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
],
supports_pil=False,
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
),
ConsistencyConfig(
v2_transforms.FiveCrop,
legacy_transforms.FiveCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
ConsistencyConfig(
v2_transforms.TenCrop,
legacy_transforms.TenCrop,
[
ArgsKwargs(18),
ArgsKwargs((18, 13)),
ArgsKwargs(18, vertical_flip=True),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
*[
ConsistencyConfig(
v2_transforms.LinearTransformation,
Expand Down Expand Up @@ -147,65 +119,6 @@ def __init__(
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.RandomEqualize,
legacy_transforms.RandomEqualize,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
),
ConsistencyConfig(
v2_transforms.RandomInvert,
legacy_transforms.RandomInvert,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
v2_transforms.RandomPosterize,
legacy_transforms.RandomPosterize,
[
ArgsKwargs(p=0, bits=5),
ArgsKwargs(p=1, bits=1),
ArgsKwargs(p=1, bits=3),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
),
ConsistencyConfig(
v2_transforms.RandomSolarize,
legacy_transforms.RandomSolarize,
[
ArgsKwargs(p=0, threshold=0.5),
ArgsKwargs(p=1, threshold=0.3),
ArgsKwargs(p=1, threshold=0.99),
],
),
*[
ConsistencyConfig(
v2_transforms.RandomAutocontrast,
legacy_transforms.RandomAutocontrast,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
closeness_kwargs=ckw,
)
for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
],
ConsistencyConfig(
v2_transforms.RandomAdjustSharpness,
legacy_transforms.RandomAdjustSharpness,
[
ArgsKwargs(p=0, sharpness_factor=0.5),
ArgsKwargs(p=1, sharpness_factor=0.2),
ArgsKwargs(p=1, sharpness_factor=0.99),
],
closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
),
ConsistencyConfig(
v2_transforms.PILToTensor,
legacy_transforms.PILToTensor,
Expand All @@ -230,22 +143,6 @@ def __init__(
v2_transforms.RandomOrder,
legacy_transforms.RandomOrder,
),
ConsistencyConfig(
v2_transforms.AugMix,
legacy_transforms.AugMix,
),
ConsistencyConfig(
v2_transforms.AutoAugment,
legacy_transforms.AutoAugment,
),
ConsistencyConfig(
v2_transforms.RandAugment,
legacy_transforms.RandAugment,
),
ConsistencyConfig(
v2_transforms.TrivialAugmentWide,
legacy_transforms.TrivialAugmentWide,
),
]


Expand Down Expand Up @@ -753,36 +650,9 @@ def test_common(self, t_ref, t, data_kwargs):
(legacy_F.pil_to_tensor, {}),
(legacy_F.convert_image_dtype, {}),
(legacy_F.to_pil_image, {}),
(legacy_F.normalize, {}),
(legacy_F.resize, {"interpolation"}),
(legacy_F.pad, {"padding", "fill"}),
(legacy_F.crop, {}),
(legacy_F.center_crop, {}),
(legacy_F.resized_crop, {"interpolation"}),
(legacy_F.hflip, {}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
(legacy_F.vflip, {}),
(legacy_F.five_crop, {}),
(legacy_F.ten_crop, {}),
(legacy_F.adjust_brightness, {}),
(legacy_F.adjust_contrast, {}),
(legacy_F.adjust_saturation, {}),
(legacy_F.adjust_hue, {}),
(legacy_F.adjust_gamma, {}),
(legacy_F.rotate, {"center", "fill", "interpolation"}),
(legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
(legacy_F.to_grayscale, {}),
(legacy_F.rgb_to_grayscale, {}),
(legacy_F.to_tensor, {}),
(legacy_F.erase, {}),
(legacy_F.gaussian_blur, {}),
(legacy_F.invert, {}),
(legacy_F.posterize, {}),
(legacy_F.solarize, {}),
(legacy_F.adjust_sharpness, {}),
(legacy_F.autocontrast, {}),
(legacy_F.equalize, {}),
(legacy_F.elastic_transform, {"fill", "interpolation"}),
],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
Expand Down
Loading

0 comments on commit 1a9ff0d

Please sign in to comment.