Skip to content

Commit d93938e

Browse files
committed
Simplify normalize testing into a single test, parameterized on input creation
1 parent c16a033 commit d93938e

File tree

5 files changed: +38 −18 lines changed

test/test_transforms_v2.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5597,24 +5597,36 @@ def _reference_normalize_image(self, image, *, mean, std):
55975597

55985598
@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
55995599
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
5600+
@pytest.mark.parametrize(
5601+
"make_input",
5602+
[
5603+
make_image,
5604+
pytest.param(
5605+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5606+
),
5607+
],
5608+
)
56005609
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
5601-
def test_correctness_image(self, mean, std, dtype, fn):
5602-
image = make_image(dtype=dtype)
5610+
def test_correctness_image(self, mean, std, dtype, make_input, fn):
5611+
if make_input == make_image_cvcuda and dtype != torch.float32:
5612+
pytest.skip("CVCUDA only supports float32 for normalize")
5613+
5614+
image = make_input(dtype=dtype)
56035615

56045616
actual = fn(image, mean=mean, std=std)
5605-
expected = self._reference_normalize_image(image, mean=mean, std=std)
56065617

5607-
assert_equal(actual, expected)
5618+
if make_input == make_image_cvcuda:
5619+
image = F.cvcuda_to_tensor(image).to(device="cpu")
5620+
image = image.squeeze(0)
5621+
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
5622+
actual = actual.squeeze(0)
56085623

5609-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5610-
@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
5611-
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
5612-
def test_correctness_cvcuda(self, mean, std, fn):
5613-
image = make_image(batch_dims=(1,), dtype=torch.float32, device="cuda")
5614-
cvc_image = F.to_cvcuda_tensor(image)
5615-
actual = F._misc._normalize_cvcuda(cvc_image, mean=mean, std=std)
5616-
expected = fn(image, mean=mean, std=std)
5617-
torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=1e-7, atol=1e-7)
5624+
expected = self._reference_normalize_image(image, mean=mean, std=std)
5625+
5626+
if make_input == make_image_cvcuda:
5627+
torch.testing.assert_close(actual, expected, rtol=0, atol=1e-6)
5628+
else:
5629+
assert_equal(actual, expected)
56185630

56195631

56205632
class TestClampBoundingBoxes:

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch import nn
99
from torch.utils._pytree import tree_flatten, tree_unflatten
1010
from torchvision import tv_tensors
11-
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
11+
from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor
1212
from torchvision.utils import _log_api_usage_once
1313

1414
from .functional._utils import _get_kernel
@@ -23,7 +23,7 @@ class Transform(nn.Module):
2323

2424
# Class attribute defining transformed types. Other types are passed-through without any transformation
2525
# We support both Types and callables that are able to do further checks on the type of the input.
26-
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
26+
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
2727

2828
def __init__(self) -> None:
2929
super().__init__()

torchvision/transforms/v2/_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torchvision._utils import sequence_to_str
1616

1717
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
18-
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
18+
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
1919
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
2020

2121

@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
182182
chws = {
183183
tuple(get_dimensions(inpt))
184184
for inpt in flat_inputs
185-
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
185+
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
186186
}
187187
if not chws:
188188
raise TypeError("No image or video was found in the sample")
@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
207207
tv_tensors.Mask,
208208
tv_tensors.BoundingBoxes,
209209
tv_tensors.KeyPoints,
210+
is_cvcuda_tensor,
210211
),
211212
)
212213
}

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from torchvision.transforms import InterpolationMode # usort: skip
22

3-
from ._utils import is_pure_tensor, register_kernel # usort: skip
3+
from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip
44

55
from ._meta import (
66
clamp_bounding_boxes,

torchvision/transforms/v2/functional/_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,10 @@ def _is_cvcuda_available():
169169
return True
170170
except ImportError:
171171
return False
172+
173+
174+
def is_cvcuda_tensor(inpt: Any) -> bool:
175+
if _is_cvcuda_available():
176+
cvcuda = _import_cvcuda()
177+
return isinstance(inpt, cvcuda.Tensor)
178+
return False

0 commit comments

Comments (0)