Skip to content

Commit 778ad32

Browse files
committed
update normalize based on PR reviews
1 parent d93938e commit 778ad32

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

test/common_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2121
from torchvision import io, tv_tensors
2222
from torchvision.transforms._functional_tensor import _max_value as get_max_value
23-
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
23+
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
24+
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
2425
from torchvision.utils import _Image_fromarray
2526

2627

2728
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
2829
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
2930
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
31+
CVCUDA_AVAILABLE = _is_cvcuda_available()
3032
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
3133
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
3234
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
275277
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
276278

277279

280+
def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
281+
tensor = cvcuda_to_tensor(tensor)
282+
if tensor.ndim != 4:
283+
raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
284+
if tensor.shape[0] != 1:
285+
raise ValueError(
286+
f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
287+
)
288+
return tensor.squeeze(0).cpu()
289+
290+
278291
class ImagePair(TensorLikePair):
279292
def __init__(
280293
self,
@@ -287,6 +300,11 @@ def __init__(
287300
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
288301
actual, expected = (to_image(input) for input in [actual, expected])
289302

303+
# handle check for CV-CUDA Tensors
304+
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
305+
# Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
306+
actual = cvcuda_to_pil_compatible_tensor(actual)
307+
290308
super().__init__(actual, expected, **other_parameters)
291309
self.mae = mae
292310

@@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs):
401419

402420

403421
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
404-
# explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4)
405422
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))
406423

407424

test/test_transforms_v2.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import torchvision.transforms.v2 as transforms
2222

2323
from common_utils import (
24+
assert_close,
2425
assert_equal,
2526
cache,
2627
cpu_and_cuda,
28+
cvcuda_to_pil_compatible_tensor,
2729
freeze_rng_state,
2830
ignore_jit_no_profile_information_warning,
2931
make_bounding_boxes,
@@ -41,7 +43,6 @@
4143
)
4244

4345
from torch import nn
44-
from torch.testing import assert_close
4546
from torch.utils._pytree import tree_flatten, tree_map
4647
from torch.utils.data import DataLoader, default_collate
4748
from torchvision import tv_tensors
@@ -5500,17 +5501,17 @@ def test_kernel_image(self, mean, std, device):
55005501

55015502
@pytest.mark.parametrize("device", cpu_and_cuda())
55025503
def test_kernel_image_inplace(self, device):
5503-
input = make_image_tensor(dtype=torch.float32, device=device)
5504-
input_version = input._version
5504+
inpt = make_image_tensor(dtype=torch.float32, device=device)
5505+
input_version = inpt._version
55055506

5506-
output_out_of_place = F.normalize_image(input, mean=self.MEAN, std=self.STD)
5507-
assert output_out_of_place.data_ptr() != input.data_ptr()
5508-
assert output_out_of_place is not input
5507+
output_out_of_place = F.normalize_image(inpt, mean=self.MEAN, std=self.STD)
5508+
assert output_out_of_place.data_ptr() != inpt.data_ptr()
5509+
assert output_out_of_place is not inpt
55095510

5510-
output_inplace = F.normalize_image(input, mean=self.MEAN, std=self.STD, inplace=True)
5511-
assert output_inplace.data_ptr() == input.data_ptr()
5511+
output_inplace = F.normalize_image(inpt, mean=self.MEAN, std=self.STD, inplace=True)
5512+
assert output_inplace.data_ptr() == inpt.data_ptr()
55125513
assert output_inplace._version > input_version
5513-
assert output_inplace is input
5514+
assert output_inplace is inpt
55145515

55155516
assert_equal(output_inplace, output_out_of_place)
55165517

@@ -5560,9 +5561,9 @@ def test_functional_error(self):
55605561
with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"):
55615562
F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std)
55625563

5563-
def _sample_input_adapter(self, transform, input, device):
5564+
def _sample_input_adapter(self, transform, inpt, device):
55645565
adapted_input = {}
5565-
for key, value in input.items():
5566+
for key, value in inpt.items():
55665567
if isinstance(value, PIL.Image.Image):
55675568
# normalize doesn't support PIL images
55685569
continue
@@ -5616,15 +5617,12 @@ def test_correctness_image(self, mean, std, dtype, make_input, fn):
56165617
actual = fn(image, mean=mean, std=std)
56175618

56185619
if make_input == make_image_cvcuda:
5619-
image = F.cvcuda_to_tensor(image).to(device="cpu")
5620-
image = image.squeeze(0)
5621-
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
5622-
actual = actual.squeeze(0)
5620+
image = cvcuda_to_pil_compatible_tensor(image)
56235621

56245622
expected = self._reference_normalize_image(image, mean=mean, std=std)
56255623

56265624
if make_input == make_image_cvcuda:
5627-
torch.testing.assert_close(actual, expected, rtol=0, atol=1e-6)
5625+
assert_close(actual, expected, rtol=0, atol=1e-6)
56285626
else:
56295627
assert_equal(actual, expected)
56305628

torchvision/transforms/v2/_misc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_bounding_boxes,
1818
get_keypoints,
1919
has_any,
20+
is_cvcuda_tensor,
2021
is_pure_tensor,
2122
)
2223

@@ -160,6 +161,8 @@ class Normalize(Transform):
160161

161162
_v1_transform_cls = _transforms.Normalize
162163

164+
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
165+
163166
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
164167
super().__init__()
165168
self.mean = list(mean)

0 commit comments

Comments
 (0)