Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 99 additions & 6 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -812,6 +812,9 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional(self, size, make_input):
Expand All @@ -835,9 +838,16 @@ def test_functional(self, size, make_input):
(F.resize_mask, tv_tensors.Mask),
(F.resize_video, tv_tensors.Video),
(F.resize_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._resize_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._geometry._resize_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize("size", OUTPUT_SIZES)
Expand All @@ -853,6 +863,9 @@ def test_functional_signature(self, kernel, input_type):
make_detection_masks,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform(self, size, device, make_input):
Expand All @@ -870,23 +883,72 @@ def _check_output_size(self, input, output, *, size, max_size):
input_size=F.get_size(input), size=size, max_size=max_size
)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("size", OUTPUT_SIZES)
# `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
# The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
@pytest.mark.parametrize("use_max_size", [True, False])
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
def test_image_correctness(self, size, interpolation, use_max_size, fn):
def test_image_correctness(self, make_input, size, interpolation, use_max_size, fn):
if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
return

image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
image = make_input(self.INPUT_SIZE, dtype=torch.uint8)

actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg))

self._check_output_size(image, actual, size=size, **max_size_kwarg)
torch.testing.assert_close(actual, expected, atol=1, rtol=0)

atol = 1
# when using antialias, CV-CUDA is different for BICUBIC and BILINEAR, since antialias requires hq_resize
if make_input is make_image_cvcuda and (
interpolation is transforms.InterpolationMode.BILINEAR
or interpolation is transforms.InterpolationMode.BICUBIC
):
atol = 9
assert_close(actual, expected, atol=atol, rtol=0)

@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.parametrize("size", OUTPUT_SIZES)
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
@pytest.mark.parametrize("use_max_size", [True, False])
@pytest.mark.parametrize("antialias", [True, False])
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
def test_image_correctness_cvcuda(self, size, interpolation, use_max_size, antialias, fn):
if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
return

image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8)
actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=antialias)
expected = fn(
F.cvcuda_to_tensor(image), size=size, interpolation=interpolation, **max_size_kwarg, antialias=antialias
)

# assert_close will squeeze the batch dimension off the CV-CUDA tensor so we convert ahead of time
actual = F.cvcuda_to_tensor(actual)

atol = 1
if antialias:
# cvcuda.hq_resize is accurate within 9 for the tests
atol = 9
elif interpolation == transforms.InterpolationMode.BICUBIC:
# the CV-CUDA bicubic interpolation differs significantly
atol = 91
assert_close(actual, expected, atol=atol, rtol=0)

def _reference_resize_bounding_boxes(self, bounding_boxes, format, *, size, max_size=None):
old_height, old_width = bounding_boxes.canvas_size
Expand Down Expand Up @@ -972,11 +1034,25 @@ def test_keypoints_correctness(self, size, use_max_size, fn):
@pytest.mark.parametrize("interpolation", set(transforms.InterpolationMode) - set(INTERPOLATION_MODES))
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_pil_interpolation_compat_smoke(self, interpolation, make_input):
input = make_input(self.INPUT_SIZE)

if make_input is make_image_cvcuda and interpolation in {
transforms.InterpolationMode.BOX,
transforms.InterpolationMode.LANCZOS,
}:
pytest.skip("CV-CUDA may support box and lanczos for certain configurations of resize")

with (
contextlib.nullcontext()
if isinstance(input, PIL.Image.Image)
Expand Down Expand Up @@ -1005,6 +1081,9 @@ def test_functional_pil_antialias_warning(self):
make_detection_masks,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_max_size_error(self, size, make_input):
Expand Down Expand Up @@ -1048,6 +1127,9 @@ def test_max_size_error(self, size, make_input):
make_detection_masks,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_resize_size_none(self, input_size, max_size, expected_size, make_input):
Expand All @@ -1058,7 +1140,15 @@ def test_resize_size_none(self, input_size, max_size, expected_size, make_input)
@pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_interpolation_int(self, interpolation, make_input):
input = make_input(self.INPUT_SIZE)
Expand Down Expand Up @@ -1122,6 +1212,9 @@ def test_noop(self, size, make_input):
make_detection_masks,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_no_regression_5405(self, make_input):
Expand Down
3 changes: 3 additions & 0 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ class Resize(Transform):

_v1_transform_cls = _transforms.Resize

if CVCUDA_AVAILABLE:
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(
self,
size: Union[int, Sequence[int], None],
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor


def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
Expand All @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
_is_cvcuda_tensor,
),
)
}
Expand Down
77 changes: 77 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from ._utils import (
_FillTypeJIT,
_get_cvcuda_interp,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
Expand Down Expand Up @@ -401,6 +402,82 @@ def __resize_image_pil_dispatch(
return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)


_dtype_to_format_cvcuda: dict["cvcuda.Type", "cvcuda.Format"] = {}


def _resize_image_cvcuda(
image: "cvcuda.Tensor",
size: Optional[list[int]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = True,
) -> "cvcuda.Tensor":
cvcuda = _import_cvcuda()

if len(_dtype_to_format_cvcuda) == 0:
_dtype_to_format_cvcuda[cvcuda.Type.U8] = cvcuda.Format.U8
_dtype_to_format_cvcuda[cvcuda.Type.U16] = cvcuda.Format.U16
_dtype_to_format_cvcuda[cvcuda.Type.U32] = cvcuda.Format.U32
_dtype_to_format_cvcuda[cvcuda.Type.S8] = cvcuda.Format.S8
_dtype_to_format_cvcuda[cvcuda.Type.S16] = cvcuda.Format.S16
_dtype_to_format_cvcuda[cvcuda.Type.S32] = cvcuda.Format.S32
_dtype_to_format_cvcuda[cvcuda.Type.F32] = cvcuda.Format.F32
_dtype_to_format_cvcuda[cvcuda.Type.F64] = cvcuda.Format.F64

interp = _get_cvcuda_interp(interpolation)
# hamming error for parity to resize_image
if interp == cvcuda.Interp.HAMMING:
raise NotImplementedError("Unsupported interpolation for CV-CUDA resize, got hamming.")

# match the antialias behavior of resize_image
if not (interp == cvcuda.Interp.LINEAR or interp == cvcuda.Interp.CUBIC):
antialias = False

old_height, old_width = image.shape[1], image.shape[2]
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

# No resize needed if dimensions match
if new_height == old_height and new_width == old_width:
return image

# antialias is only supported for cvcuda.hq_resize, if set to true (which is also default)
# we will fast-track to use hq_resize (also matchs the size parameter)
if antialias:
return cvcuda.hq_resize(
image,
out_size=(new_height, new_width),
interpolation=interp,
antialias=antialias,
)

# if not using antialias, we will use cvcuda.resize/pillowresize instead
# resize requires that the shape has the same dimensions as the input
# CV-CUDA tensors are already in NHWC format so we can do a simple tuple creation
shape = image.shape
new_shape = (shape[0], new_height, new_width, shape[3])

# bicubic mode is not accurate when using cvcuda.resize
# cvcuda.pillowresize resolves some of the errors
if interp == cvcuda.Interp.CUBIC:
return cvcuda.pillowresize(
image,
shape=new_shape,
format=_dtype_to_format_cvcuda[image.dtype],
interp=interp,
)

# otherwise we will use cvcuda.resize
return cvcuda.resize(
image,
shape=new_shape,
interp=interp,
)


if CVCUDA_AVAILABLE:
_register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_image_cvcuda)


def resize_mask(mask: torch.Tensor, size: Optional[list[int]], max_size: Optional[int] = None) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
Expand Down
22 changes: 21 additions & 1 deletion torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]:
return get_dimensions_image(video)


def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
# CV-CUDA tensor is always in NHWC layout
# get_dimensions is CHW
return [image.shape[3], image.shape[1], image.shape[2]]


if CVCUDA_AVAILABLE:
_register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda)


def get_num_channels(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting():
return get_num_channels_image(inpt)
Expand Down Expand Up @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int:
get_image_num_channels = get_num_channels


def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int:
# CV-CUDA tensor is always in NHWC layout
# get_num_channels is C
return image.shape[3]


if CVCUDA_AVAILABLE:
_register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda)


def get_size(inpt: torch.Tensor) -> list[int]:
if torch.jit.is_scripting():
return get_size_image(inpt)
Expand Down Expand Up @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:


if CVCUDA_AVAILABLE:
_get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
_register_kernel_internal(get_size, _import_cvcuda().Tensor)(get_size_image_cvcuda)


@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
Expand Down
Loading