Skip to content

Commit 9ae382d

Browse files
authored
feat: downsample oversized textures (#260)
* feat: downsample oversided 3d views * typing * one threed check
1 parent 7b56e80 commit 9ae382d

8 files changed

Lines changed: 592 additions & 29 deletions

File tree

src/ndv/models/_data_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def create(cls, data: ArrayT) -> DataWrapper[ArrayT]:
209209
for subclass in sorted(_recurse_subclasses(cls), key=lambda x: x.PRIORITY):
210210
try:
211211
if subclass.supports(data):
212-
logger.debug(f"Using {subclass.__name__} to wrap {type(data)}")
212+
logger.debug("Using %s to wrap %s", subclass.__name__, type(data))
213213
return subclass(data)
214214
except Exception as e:
215215
warnings.warn(

src/ndv/views/_pygfx/_array_canvas.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
from contextlib import suppress
4-
from typing import TYPE_CHECKING, Any, Literal, cast
4+
from functools import lru_cache
5+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
56
from weakref import ReferenceType, WeakValueDictionary, ref
67

78
import cmap as _cmap
@@ -18,6 +19,7 @@
1819
)
1920
from ndv.models._viewer_model import ArrayViewerModel, InteractionMode
2021
from ndv.views._app import filter_mouse_events
22+
from ndv.views._util import downsample_data
2123
from ndv.views.bases import ArrayCanvas, CanvasElement, ImageHandle
2224
from ndv.views.bases._graphics._canvas_elements import RectangularROIHandle, ROIMoveMode
2325

@@ -63,22 +65,33 @@ def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None:
6365
self._render = render
6466
self._grid = cast("Texture", image.geometry.grid)
6567
self._material = cast("ImageBasicMaterial", image.material)
68+
# per-axis downsample strides applied to fit GPU texture limits
69+
self._downsample_factors: tuple[int, ...] = ()
6670

6771
def data(self) -> np.ndarray:
6872
return self._grid.data # type: ignore [no-any-return]
6973

7074
def set_data(self, data: np.ndarray) -> None:
75+
is_three_d = isinstance(self._image, pygfx.Volume)
76+
data, self._downsample_factors = _downcast_and_downsample(
77+
data,
78+
three_d=is_three_d,
79+
warn=False,
80+
copy=False,
81+
)
7182
# If dimensions are unchanged, reuse the buffer
7283
if data.shape == self._grid.data.shape:
7384
self._grid.data[:] = data # pyright: ignore[reportOptionalSubscript]
7485
self._grid.update_range((0, 0, 0), self._grid.size)
7586
# Otherwise, the size (and maybe number of dimensions) changed
7687
# - we need a new buffer
7788
else:
78-
self._grid = pygfx.Texture(data, dim=2)
89+
dim = 3 if is_three_d else 2
90+
self._grid = pygfx.Texture(data, dim=dim)
7991
self._image.geometry = pygfx.Geometry(grid=self._grid)
8092
# RGB images (i.e. 3D datasets) cannot have a colormap
81-
self._material.map = None if self._is_rgb() else self._cmap.to_pygfx()
93+
if not is_three_d:
94+
self._material.map = None if self._is_rgb() else self._cmap.to_pygfx()
8295

8396
def visible(self) -> bool:
8497
return bool(self._image.visible)
@@ -465,11 +478,7 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None:
465478

466479
def add_image(self, data: np.ndarray | None = None) -> PyGFXImageHandle:
467480
"""Add a new Image node to the scene."""
468-
if data is not None:
469-
# pygfx uses a view of the data without copy, so if we don't
470-
# copy it here, the original data will be modified when the
471-
# texture changes.
472-
data = data.copy()
481+
data, downsample_factors = _downcast_and_downsample(data, three_d=False)
473482
tex = pygfx.Texture(data, dim=2)
474483
image = pygfx.Image(
475484
pygfx.Geometry(grid=tex),
@@ -485,15 +494,12 @@ def add_image(self, data: np.ndarray | None = None) -> PyGFXImageHandle:
485494
# FIXME: I suspect there are more performant ways to refresh the canvas
486495
# look into it.
487496
handle = PyGFXImageHandle(image, self.refresh)
497+
handle._downsample_factors = downsample_factors
488498
self._elements[image] = handle
489499
return handle
490500

491501
def add_volume(self, data: np.ndarray | None = None) -> PyGFXImageHandle:
492-
if data is not None:
493-
# pygfx uses a view of the data without copy, so if we don't
494-
# copy it here, the original data will be modified when the
495-
# texture changes.
496-
data = data.copy()
502+
data, downsample_factors = _downcast_and_downsample(data, three_d=True)
497503
tex = pygfx.Texture(data, dim=3)
498504
vol = pygfx.Volume(
499505
pygfx.Geometry(grid=tex),
@@ -512,6 +518,7 @@ def add_volume(self, data: np.ndarray | None = None) -> PyGFXImageHandle:
512518
# FIXME: I suspect there are more performant ways to refresh the canvas
513519
# look into it.
514520
handle = PyGFXImageHandle(vol, self.refresh)
521+
handle._downsample_factors = downsample_factors
515522
self._elements[vol] = handle
516523
return handle
517524

@@ -539,10 +546,23 @@ def set_scales(self, scales: tuple[float, ...]) -> None:
539546
gfx_scales.append(1.0)
540547
sx, sy, sz = gfx_scales[0], gfx_scales[1], gfx_scales[2]
541548
has_visuals = False
542-
for child in self._scene.children:
543-
if isinstance(child, (pygfx.Image, pygfx.Volume)):
544-
child.local.scale = (sx, sy, sz)
545-
has_visuals = True
549+
for handle in self._elements.values():
550+
if not isinstance(handle, PyGFXImageHandle):
551+
continue
552+
child = handle._image
553+
if not isinstance(child, (pygfx.Image, pygfx.Volume)):
554+
continue
555+
_sx, _sy, _sz = sx, sy, sz
556+
# compensate for downsampling so coordinates stay correct
557+
# factors are in data order; pygfx order is (x, y, z) = reversed
558+
factors = handle._downsample_factors
559+
if factors and any(f > 1 for f in factors):
560+
rev = list(reversed(factors))
561+
_sx *= rev[0]
562+
_sy *= rev[1] if len(rev) > 1 else 1
563+
_sz *= rev[2] if len(rev) > 2 else 1
564+
child.local.scale = (_sx, _sy, _sz)
565+
has_visuals = True
546566
if has_visuals:
547567
self.set_range()
548568

@@ -710,3 +730,37 @@ def get_cursor(self, event: MouseMoveEvent) -> CursorType:
710730
if cursor := vis.get_cursor(event):
711731
return cursor
712732
return CursorType.DEFAULT
733+
734+
735+
T = TypeVar("T", bound=np.ndarray | None)
736+
737+
738+
@lru_cache(maxsize=1)
739+
def _get_max_texture_sizes() -> tuple[int | None, int | None]:
740+
"""Return (max_2d, max_3d) texture dimensions from the wgpu adapter."""
741+
try:
742+
import wgpu
743+
744+
adapter = wgpu.gpu.request_adapter_sync()
745+
limits = adapter.limits
746+
max_2d = limits.get("max-texture-dimension-2d")
747+
max_3d = limits.get("max-texture-dimension-3d")
748+
return max_2d, max_3d
749+
except Exception:
750+
return None, None
751+
752+
753+
def _downcast_and_downsample(
754+
data: T, three_d: bool, *, warn: bool = True, copy: bool = True
755+
) -> tuple[T, tuple[int, ...]]:
756+
downsample_factors: tuple[int, ...] = ()
757+
if data is not None:
758+
if copy:
759+
# pygfx uses a view of the data without copy, so if we don't
760+
# copy it here, the original data will be modified when the
761+
# texture changes.
762+
data = data.copy()
763+
maxd = _get_max_texture_sizes()[1 if three_d else 0]
764+
if maxd is not None:
765+
data, downsample_factors = downsample_data(data, maxd, warn=warn) # type: ignore[assignment]
766+
return data, downsample_factors # pyright: ignore[reportReturnType]

src/ndv/views/_util.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Shared utilities for canvas backends."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
7+
import numpy as np
8+
9+
logger = logging.getLogger("ndv")
10+
11+
12+
def downsample_data(
13+
data: np.ndarray, max_size: int, *, warn: bool = True
14+
) -> tuple[np.ndarray, tuple[int, ...]]:
15+
"""Downsample data so no axis exceeds max_size.
16+
17+
Returns the (possibly downsampled view) data and the per-axis stride factors.
18+
"""
19+
factors = tuple(
20+
int(np.ceil(s / max_size)) if s > max_size else 1 for s in data.shape
21+
)
22+
if any(f > 1 for f in factors):
23+
if warn:
24+
logger.warning(
25+
"Data shape %s exceeds max texture dimension (%d) and will be "
26+
"downsampled for rendering (strides: %s).",
27+
data.shape,
28+
max_size,
29+
factors,
30+
)
31+
slices = tuple(slice(None, None, f) for f in factors)
32+
data = data[slices]
33+
return data, factors

src/ndv/views/_vispy/_array_canvas.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import warnings
55
from contextlib import suppress
6-
from typing import TYPE_CHECKING, Any, Literal, cast
6+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
77
from weakref import ReferenceType, WeakValueDictionary
88

99
import cmap as _cmap
@@ -23,6 +23,8 @@
2323
)
2424
from ndv.models._viewer_model import ArrayViewerModel, InteractionMode
2525
from ndv.views._app import filter_mouse_events
26+
from ndv.views._util import downsample_data
27+
from ndv.views._vispy._util import get_max_texture_sizes
2628
from ndv.views.bases import ArrayCanvas
2729
from ndv.views.bases._graphics._canvas_elements import (
2830
CanvasElement,
@@ -43,6 +45,8 @@ class VispyImageHandle(ImageHandle):
4345
def __init__(self, visual: visuals.ImageVisual | visuals.VolumeVisual) -> None:
4446
self._visual = visual
4547
self._allowed_dims = {2, 3} if isinstance(visual, visuals.ImageVisual) else {3}
48+
# per-axis downsample strides applied to fit GPU texture limits
49+
self._downsample_factors: tuple[int, ...] = ()
4650

4751
def data(self) -> np.ndarray:
4852
try:
@@ -58,6 +62,13 @@ def set_data(self, data: np.ndarray) -> None:
5862
stacklevel=2,
5963
)
6064
return
65+
66+
data, downsample_factors = _downcast_and_downsample(
67+
data,
68+
three_d=isinstance(self._visual, visuals.VolumeVisual),
69+
warn=False,
70+
)
71+
self._downsample_factors = downsample_factors
6172
self._visual.set_data(data)
6273

6374
def visible(self) -> bool:
@@ -358,7 +369,7 @@ def refresh(self) -> None:
358369

359370
def add_image(self, data: np.ndarray | None = None) -> VispyImageHandle:
360371
"""Add a new Image node to the scene."""
361-
data = _downcast(data)
372+
data, downsample_factors = _downcast_and_downsample(data, three_d=False)
362373
try:
363374
img = scene.visuals.Image(
364375
data, parent=self._view.scene, texture_format="auto"
@@ -370,13 +381,14 @@ def add_image(self, data: np.ndarray | None = None) -> VispyImageHandle:
370381
img.set_gl_state("additive", depth_test=False)
371382
img.interactive = True
372383
handle = VispyImageHandle(img)
384+
handle._downsample_factors = downsample_factors
373385
self._elements[img] = handle
374386
if data is not None:
375387
self.set_range()
376388
return handle
377389

378390
def add_volume(self, data: np.ndarray | None = None) -> VispyImageHandle:
379-
data = _downcast(data)
391+
data, downsample_factors = _downcast_and_downsample(data, three_d=True)
380392
try:
381393
vol = scene.visuals.Volume(
382394
data,
@@ -393,6 +405,7 @@ def add_volume(self, data: np.ndarray | None = None) -> VispyImageHandle:
393405
vol.set_gl_state("additive", depth_test=False)
394406
vol.interactive = True
395407
handle = VispyImageHandle(vol)
408+
handle._downsample_factors = downsample_factors
396409
self._elements[vol] = handle
397410
if data is not None:
398411
self.set_range()
@@ -418,11 +431,24 @@ def set_scales(self, scales: tuple[float, ...]) -> None:
418431
while len(vis_scales) < 3:
419432
vis_scales.append(1.0)
420433
sx, sy, sz = vis_scales[0], vis_scales[1], vis_scales[2]
421-
for child in self._view.scene.children:
422-
if isinstance(child, (visuals.ImageVisual, visuals.VolumeVisual)):
423-
child.transform = vispy.visuals.transforms.STTransform(
424-
scale=(sx, sy, sz)
425-
)
434+
for handle in self._elements.values():
435+
if not isinstance(handle, VispyImageHandle):
436+
continue
437+
child = handle._visual
438+
if not isinstance(child, (visuals.ImageVisual, visuals.VolumeVisual)):
439+
continue
440+
_sx, _sy, _sz = sx, sy, sz
441+
# compensate for downsampling so coordinates stay correct
442+
# factors are in data order; scene order is (x, y, z) = reversed
443+
factors = handle._downsample_factors
444+
if factors and any(f > 1 for f in factors):
445+
rev = list(reversed(factors))
446+
_sx *= rev[0]
447+
_sy *= rev[1] if len(rev) > 1 else 1
448+
_sz *= rev[2] if len(rev) > 2 else 1
449+
child.transform = vispy.visuals.transforms.STTransform(
450+
scale=(_sx, _sy, _sz)
451+
)
426452
self.set_range()
427453

428454
def set_range(
@@ -561,13 +587,29 @@ def get_cursor(self, event: MouseMoveEvent) -> CursorType:
561587
return CursorType.DEFAULT
562588

563589

564-
def _downcast(data: np.ndarray | None) -> np.ndarray | None:
590+
T = TypeVar("T", bound="np.ndarray | None")
591+
592+
593+
def _downcast(data: T) -> T:
565594
"""Downcast >32bit data to 32bit."""
566595
# downcast to 32bit, preserving int/float
567596
if data is not None:
568597
if np.issubdtype(data.dtype, np.integer) and data.dtype.itemsize > 2:
569598
warnings.warn("Downcasting integer data to uint16.", stacklevel=2)
570-
data = data.astype(np.uint16)
599+
data = data.astype(np.uint16) # type: ignore[assignment]
571600
elif np.issubdtype(data.dtype, np.floating) and data.dtype.itemsize > 4:
572-
data = data.astype(np.float32)
601+
data = data.astype(np.float32) # type: ignore[assignment]
573602
return data
603+
604+
605+
def _downcast_and_downsample(
606+
data: T, three_d: bool, warn: bool = True
607+
) -> tuple[T, tuple[int, ...]]:
608+
"""Downcast >32bit data to 32bit, and downsample GPU texture limits are exceeded."""
609+
data = _downcast(data)
610+
downsample_factors: tuple[int, ...] = ()
611+
if data is not None:
612+
maxd = get_max_texture_sizes()[1 if three_d else 0]
613+
if maxd is not None:
614+
data, downsample_factors = downsample_data(data, maxd, warn=warn) # type: ignore[assignment]
615+
return data, downsample_factors

src/ndv/views/_vispy/_util.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from functools import lru_cache
5+
from typing import TYPE_CHECKING
6+
7+
from vispy.app import Canvas
8+
from vispy.gloo import gl
9+
from vispy.gloo.context import get_current_canvas
10+
11+
if TYPE_CHECKING:
12+
from collections.abc import Generator
13+
14+
15+
@contextmanager
16+
def _opengl_context() -> Generator[None, None, None]:
17+
"""Assure we are running with a valid OpenGL context.
18+
19+
Only create a Canvas is one doesn't exist. Creating and closing a
20+
Canvas causes vispy to process Qt events which can cause problems.
21+
"""
22+
canvas = Canvas(show=False) if get_current_canvas() is None else None
23+
try:
24+
yield
25+
finally:
26+
if canvas is not None:
27+
canvas.close()
28+
29+
30+
@lru_cache
31+
def get_max_texture_sizes() -> tuple[int | None, int | None]:
32+
"""Return the maximum texture sizes for 2D and 3D rendering.
33+
34+
Returns
35+
-------
36+
Tuple[int | None, int | None]
37+
The max textures sizes for (2d, 3d) rendering.
38+
"""
39+
with _opengl_context():
40+
max_size_2d = gl.glGetParameter(gl.GL_MAX_TEXTURE_SIZE)
41+
42+
if not max_size_2d:
43+
max_size_2d = None
44+
45+
# vispy/gloo doesn't provide the GL_MAX_3D_TEXTURE_SIZE location,
46+
# but it can be found in this list of constants
47+
# http://pyopengl.sourceforge.net/documentation/pydoc/OpenGL.GL.html
48+
with _opengl_context():
49+
GL_MAX_3D_TEXTURE_SIZE = 32883
50+
max_size_3d = gl.glGetParameter(GL_MAX_3D_TEXTURE_SIZE)
51+
52+
if not max_size_3d:
53+
max_size_3d = None
54+
55+
return max_size_2d, max_size_3d

0 commit comments

Comments
 (0)