Skip to content

Commit a00a7af

Browse files
committed
update to include five ten crop and resized crop, use placeholder transforms for flip and resize for now
1 parent 31e08e4 commit a00a7af

File tree

4 files changed

+99
-27
lines changed

4 files changed

+99
-27
lines changed

test/common_utils.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,15 @@
2020
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2121
from torchvision import io, tv_tensors
2222
from torchvision.transforms._functional_tensor import _max_value as get_max_value
23-
from torchvision.transforms.v2.functional import (
24-
cvcuda_to_tensor,
25-
is_cvcuda_tensor,
26-
to_cvcuda_tensor,
27-
to_image,
28-
to_pil_image,
29-
)
23+
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
24+
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
3025
from torchvision.utils import _Image_fromarray
3126

3227

3328
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
3429
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
3530
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
31+
CVCUDA_AVAILABLE = _is_cvcuda_available()
3632
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
3733
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
3834
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -307,11 +303,10 @@ def __init__(
307303
if isinstance(expected, PIL.Image.Image):
308304
expected = to_image(expected)
309305

310-
# attempt to convert CV-CUDA tensors to torch tensors
311-
if is_cvcuda_tensor(actual):
306+
# handle check for CV-CUDA Tensors
307+
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
308+
# Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
312309
actual = cvcuda_to_pil_compatible_tensor(actual)
313-
if is_cvcuda_tensor(expected):
314-
expected = cvcuda_to_pil_compatible_tensor(expected)
315310

316311
super().__init__(actual, expected, **other_parameters)
317312
self.mae = mae

test/test_transforms_v2.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torchvision.transforms.v2 as transforms
2222

2323
from common_utils import (
24+
assert_close,
2425
assert_equal,
2526
cache,
2627
cpu_and_cuda,
@@ -42,7 +43,6 @@
4243
)
4344

4445
from torch import nn
45-
from torch.testing import assert_close
4646
from torch.utils._pytree import tree_flatten, tree_map
4747
from torch.utils.data import DataLoader, default_collate
4848
from torchvision import tv_tensors
@@ -3499,11 +3499,8 @@ def test_functional_image_correctness(self, kwargs, make_input):
34993499

35003500
actual = F.crop(image, **kwargs)
35013501

3502-
if make_input == make_image_cvcuda:
3503-
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
3504-
actual = actual.squeeze(0)
3505-
image = F.cvcuda_to_tensor(image).to(device="cpu")
3506-
image = image.squeeze(0)
3502+
if make_input is make_image_cvcuda:
3503+
image = cvcuda_to_pil_compatible_tensor(image)
35073504

35083505
expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs))
35093506

@@ -3624,15 +3621,15 @@ def test_transform_image_correctness(self, param, value, seed, make_input):
36243621

36253622
torch.manual_seed(seed)
36263623

3627-
if make_input == make_image_cvcuda:
3624+
if make_input is make_image_cvcuda:
36283625
image = cvcuda_to_pil_compatible_tensor(image)
36293626

36303627
expected = F.to_image(transform(F.to_pil_image(image)))
36313628

36323629
if make_input == make_image_cvcuda and will_pad:
36333630
# when padding is applied, CV-CUDA will always fill with zeros
36343631
# cannot use assert_equal since it will fail unless random is all zeros
3635-
torch.testing.assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
3632+
assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype))
36363633
else:
36373634
assert_equal(actual, expected)
36383635

@@ -4458,6 +4455,9 @@ def test_kernel(self, kernel, make_input):
44584455
make_segmentation_mask,
44594456
make_video,
44604457
make_keypoints,
4458+
pytest.param(
4459+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
4460+
),
44614461
],
44624462
)
44634463
def test_functional(self, make_input):
@@ -4474,9 +4474,16 @@ def test_functional(self, make_input):
44744474
(F.resized_crop_mask, tv_tensors.Mask),
44754475
(F.resized_crop_video, tv_tensors.Video),
44764476
(F.resized_crop_keypoints, tv_tensors.KeyPoints),
4477+
pytest.param(
4478+
F.resized_crop_image,
4479+
"cvcuda.Tensor",
4480+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
4481+
),
44774482
],
44784483
)
44794484
def test_functional_signature(self, kernel, input_type):
4485+
if input_type == "cvcuda.Tensor":
4486+
input_type = _import_cvcuda().Tensor
44804487
check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=input_type)
44814488

44824489
@param_value_parametrization(
@@ -4493,6 +4500,9 @@ def test_functional_signature(self, kernel, input_type):
44934500
make_segmentation_mask,
44944501
make_video,
44954502
make_keypoints,
4503+
pytest.param(
4504+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
4505+
),
44964506
],
44974507
)
44984508
def test_transform(self, param, value, make_input):
@@ -4504,20 +4514,37 @@ def test_transform(self, param, value, make_input):
45044514

45054515
# `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
45064516
# The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
4517+
@pytest.mark.parametrize(
4518+
"make_input",
4519+
[
4520+
make_image,
4521+
pytest.param(
4522+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
4523+
),
4524+
],
4525+
)
45074526
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
4508-
def test_functional_image_correctness(self, interpolation):
4509-
image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
4527+
def test_functional_image_correctness(self, make_input, interpolation):
4528+
image = make_input(self.INPUT_SIZE, dtype=torch.uint8)
45104529

45114530
actual = F.resized_crop(
45124531
image, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation, antialias=True
45134532
)
4533+
4534+
if make_input is make_image_cvcuda:
4535+
image = cvcuda_to_pil_compatible_tensor(image)
4536+
45144537
expected = F.to_image(
45154538
F.resized_crop(
45164539
F.to_pil_image(image), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation
45174540
)
45184541
)
45194542

4520-
torch.testing.assert_close(actual, expected, atol=1, rtol=0)
4543+
atol = 1
4544+
if make_input is make_image_cvcuda and interpolation == transforms.InterpolationMode.BICUBIC:
4545+
# CV-CUDA BICUBIC differs from PIL ground truth BICUBIC
4546+
atol = 10
4547+
assert_close(actual, expected, atol=atol, rtol=0)
45214548

45224549
def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size):
45234550
new_height, new_width = size
@@ -4992,7 +5019,7 @@ def test_image_correctness(self, output_size, make_input, fn):
49925019

49935020
actual = fn(image, output_size)
49945021

4995-
if make_input == make_image_cvcuda:
5022+
if make_input is make_image_cvcuda:
49965023
image = cvcuda_to_pil_compatible_tensor(image)
49975024

49985025
expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))

torchvision/transforms/v2/_geometry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class RandomHorizontalFlip(_RandomApplyTransform):
4646

4747
_v1_transform_cls = _transforms.RandomHorizontalFlip
4848

49+
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
50+
4951
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
5052
return self._call_kernel(F.horizontal_flip, inpt)
5153

@@ -64,6 +66,8 @@ class RandomVerticalFlip(_RandomApplyTransform):
6466

6567
_v1_transform_cls = _transforms.RandomVerticalFlip
6668

69+
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
70+
6771
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
6872
return self._call_kernel(F.vertical_flip, inpt)
6973

@@ -247,6 +251,8 @@ class RandomResizedCrop(Transform):
247251

248252
_v1_transform_cls = _transforms.RandomResizedCrop
249253

254+
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
255+
250256
def __init__(
251257
self,
252258
size: Union[int, Sequence[int]],

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,32 @@ def resize_video(
605605
return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
606606

607607

608+
def _resize_cvcuda(
609+
image: "cvcuda.Tensor",
610+
size: Optional[list[int]],
611+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
612+
max_size: Optional[int] = None,
613+
antialias: Optional[bool] = True,
614+
) -> "cvcuda.Tensor":
615+
# placeholder func for now, will be handled in PR for resize alone
616+
# since placeholder convert to from torch tensor and use resize_image
617+
from ._type_conversion import cvcuda_to_tensor, to_cvcuda_tensor
618+
619+
return to_cvcuda_tensor(
620+
resize_image(
621+
cvcuda_to_tensor(image),
622+
size=size,
623+
interpolation=interpolation,
624+
max_size=max_size,
625+
antialias=antialias,
626+
)
627+
)
628+
629+
630+
if CVCUDA_AVAILABLE:
631+
_register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_cvcuda)
632+
633+
608634
def affine(
609635
inpt: torch.Tensor,
610636
angle: Union[int, float],
@@ -2946,6 +2972,24 @@ def resized_crop_video(
29462972
)
29472973

29482974

2975+
def _resized_crop_cvcuda(
2976+
image: "cvcuda.Tensor",
2977+
top: int,
2978+
left: int,
2979+
height: int,
2980+
width: int,
2981+
size: list[int],
2982+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
2983+
antialias: Optional[bool] = True,
2984+
) -> "cvcuda.Tensor":
2985+
image = _crop_cvcuda(image, top, left, height, width)
2986+
return _resize_cvcuda(image, size, interpolation=interpolation, antialias=antialias)
2987+
2988+
2989+
if CVCUDA_AVAILABLE:
2990+
_register_kernel_internal(resized_crop, _import_cvcuda().Tensor)(_resized_crop_cvcuda)
2991+
2992+
29492993
def five_crop(
29502994
inpt: torch.Tensor, size: list[int]
29512995
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -3024,15 +3068,15 @@ def _five_crop_cvcuda(
30243068
size: list[int],
30253069
) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]:
30263070
crop_height, crop_width = _parse_five_crop_size(size)
3027-
image_height, image_width = image.shape[-2:]
3071+
image_height, image_width = image.shape[1], image.shape[2]
30283072

30293073
if crop_width > image_width or crop_height > image_height:
30303074
raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
30313075

30323076
tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width)
3033-
tr = _crop_cvcuda(image, 0, image_width - crop_height, crop_width, crop_height)
3034-
bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_width, crop_height)
3035-
br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_width, crop_height)
3077+
tr = _crop_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width)
3078+
bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width)
3079+
br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
30363080
center = _center_crop_cvcuda(image, [crop_height, crop_width])
30373081

30383082
return tl, tr, bl, br, center

0 commit comments

Comments
 (0)